Compare commits
113 Commits
56ecc73d1b
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b00045f4e | |||
| e719d3c8f1 | |||
| 6908b9653b | |||
| d8e64ef0e9 | |||
| fb9f6d1d00 | |||
| e4c96da8a9 | |||
| 1195e119c6 | |||
| 4eeb14e844 | |||
| 0d30ed9a0f | |||
| cd0b3e8fc1 | |||
| c04a13bf8a | |||
| 5513f0c13d | |||
| 598e2acb7e | |||
| 4870d29638 | |||
| 8600a39b6c | |||
| 407d008e19 | |||
| a2751eab31 | |||
| 5655fc5560 | |||
| 49b47a1ae0 | |||
| bcf82d42bc | |||
| 394025c8ea | |||
| 34bd749741 | |||
| 290f299e22 | |||
| 859dec8ada | |||
| 993c0a72d6 | |||
| c9c3a84b33 | |||
| 6de7a2d2e1 | |||
| 6181923d8d | |||
| 235efb0e62 | |||
| 6b1af27ea2 | |||
| 32f48777f3 | |||
| bc7a7c6e81 | |||
| 3cd0458c2c | |||
| 8eea30ea11 | |||
| 9e33e570af | |||
| 7653385838 | |||
| 2c401f7ae6 | |||
| a9972360c2 | |||
| b00fa4dcee | |||
| 92525b39c3 | |||
| 38a2555c7b | |||
| 9622d44aac | |||
| 155244433f | |||
| 2c043c6cf7 | |||
| f5c82b6980 | |||
| 9105a36097 | |||
| f1ee646ca4 | |||
| b9b487c591 | |||
| 4c62c071fb | |||
| b2e9dd8b7f | |||
| d143c5f3df | |||
| 4eebdfb8db | |||
| b517946585 | |||
| 4ddae6be74 | |||
| 195762ff97 | |||
| bcf5ca89e5 | |||
| 365943e4c4 | |||
| 4c6b49099d | |||
| 4c78ab6cc8 | |||
| 52007c9461 | |||
| 086dd1fed7 | |||
| 1d7e839b49 | |||
| fa7babf13b | |||
| 280099b89c | |||
| 0a92a25451 | |||
| 8c075194e5 | |||
| 53e477d383 | |||
| 1522c87c74 | |||
| e0d05c9869 | |||
| 5b401e29cb | |||
| 65ac9f740a | |||
| 58ebcaa299 | |||
| 5b765c8b5e | |||
| b3258e76df | |||
| 64dc66afa6 | |||
| 15f08ee2ca | |||
| 380586afa6 | |||
| ebb70809bf | |||
| 7399afbc5c | |||
| c0669e4b07 | |||
| 05c04091b3 | |||
| 0b05e08705 | |||
| df253559a5 | |||
| 669cbb8c51 | |||
| 5ae9d85272 | |||
| 72aebef625 | |||
| f5e45d032e | |||
| b03e5f809f | |||
| ec563aaa16 | |||
| 873f09d3bf | |||
| 5e7267db07 | |||
| 7b28cee7a1 | |||
| 934c8dea77 | |||
| 7d91fe345e | |||
| 4e86adffb7 | |||
| 5d58acf5a6 | |||
| 81dcecb723 | |||
| 141f5f886f | |||
| 7fa5af483b | |||
| f488b9cc15 | |||
| 59179094ed | |||
| 4fc5fb4764 | |||
| feff97acbd | |||
| b7e205f4b6 | |||
| 24f03595a7 | |||
| 395887667d | |||
| 44d6af026a | |||
| 6e11ada42c | |||
| da790db75b | |||
| e1af978c56 | |||
| 980875ecf3 | |||
| 7f0f831226 | |||
| f3a207fa16 |
8
.claude/settings.json
Normal file
8
.claude/settings.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"tdesign-mcp-server": {
|
||||
"command": "bunx",
|
||||
"args": ["tdesign-mcp-server@latest"]
|
||||
}
|
||||
}
|
||||
}
|
||||
7
.editorconfig
Normal file
7
.editorconfig
Normal file
@@ -0,0 +1,7 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
9
.gitattributes
vendored
Normal file
9
.gitattributes
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
* text=auto eol=lf
|
||||
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
14
.github/workflows/ci.yml
vendored
Normal file
14
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [dev, master]
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check
|
||||
uses: ./.github/workflows/test.yml
|
||||
313
.github/workflows/release.yml
vendored
Normal file
313
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
name: Prepare Release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Verify tag and VERSION
|
||||
id: version
|
||||
run: |
|
||||
version=$(go run ./versionctl print)
|
||||
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
|
||||
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
|
||||
|
||||
test-gate:
|
||||
name: Test Gate
|
||||
needs: prepare
|
||||
uses: ./.github/workflows/test.yml
|
||||
with:
|
||||
full: true
|
||||
|
||||
build-web:
|
||||
name: Build Web Asset
|
||||
needs: [prepare, test-gate]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Preflight web release toolchain
|
||||
run: |
|
||||
set -euo pipefail
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
make release-assets-check
|
||||
|
||||
- name: Build web release asset
|
||||
run: make release-assets-web
|
||||
|
||||
- name: Upload web release asset
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-web
|
||||
path: build/release/*
|
||||
if-no-files-found: error
|
||||
|
||||
build-linux:
|
||||
name: Build Linux ${{ matrix.arch }} Assets
|
||||
needs: [prepare, test-gate]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- arch: amd64
|
||||
runner: ubuntu-latest
|
||||
- arch: arm64
|
||||
runner: ubuntu-24.04-arm
|
||||
runs-on: ${{ matrix.runner }}
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Install Linux desktop and package dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y curl file libayatana-appindicator3-dev libgtk-3-dev rpm
|
||||
|
||||
- name: Preflight Linux release toolchain
|
||||
run: |
|
||||
set -euo pipefail
|
||||
printf 'runner arch: %s\n' "$(uname -m)"
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v gcc
|
||||
gcc --version
|
||||
command -v pkg-config
|
||||
pkg-config --modversion ayatana-appindicator3-0.1
|
||||
pkg-config --modversion gtk+-3.0
|
||||
command -v curl
|
||||
command -v dpkg-deb
|
||||
dpkg-deb --version
|
||||
command -v rpmbuild
|
||||
rpmbuild --version
|
||||
make release-assets-check
|
||||
|
||||
- name: Build Linux release assets
|
||||
run: make release-assets-linux
|
||||
|
||||
- name: Upload Linux release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-linux-${{ matrix.arch }}
|
||||
path: build/release/*
|
||||
if-no-files-found: error
|
||||
|
||||
build-windows:
|
||||
name: Build Windows ${{ matrix.arch }} Assets
|
||||
needs: [prepare, test-gate]
|
||||
runs-on: ${{ matrix.runner }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- arch: amd64
|
||||
runner: windows-latest
|
||||
msystem: MINGW64
|
||||
cc: gcc
|
||||
cxx: g++
|
||||
packages: >-
|
||||
make
|
||||
mingw-w64-x86_64-gcc
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Setup MSYS2 toolchain
|
||||
uses: msys2/setup-msys2@v2
|
||||
with:
|
||||
msystem: ${{ matrix.msystem }}
|
||||
path-type: inherit
|
||||
update: true
|
||||
install: ${{ matrix.packages }}
|
||||
|
||||
- name: Preflight Windows release toolchain
|
||||
shell: msys2 {0}
|
||||
env:
|
||||
CC: ${{ matrix.cc }}
|
||||
CXX: ${{ matrix.cxx }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v make
|
||||
make --version
|
||||
command -v "$CC"
|
||||
"$CC" --version
|
||||
command -v "$CXX"
|
||||
"$CXX" --version
|
||||
command -v windres
|
||||
windres --version
|
||||
if command -v powershell.exe >/dev/null 2>&1; then
|
||||
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||
else
|
||||
command -v powershell
|
||||
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||
fi
|
||||
make release-assets-check
|
||||
|
||||
- name: Build Windows release assets
|
||||
shell: msys2 {0}
|
||||
env:
|
||||
CC: ${{ matrix.cc }}
|
||||
CXX: ${{ matrix.cxx }}
|
||||
run: make release-assets-windows
|
||||
|
||||
- name: Upload Windows release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-windows-${{ matrix.arch }}
|
||||
path: build/release/*
|
||||
if-no-files-found: error
|
||||
|
||||
build-macos:
|
||||
name: Build macOS Assets
|
||||
needs: [prepare, test-gate]
|
||||
runs-on: macos-15
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Preflight macOS release toolchain
|
||||
run: |
|
||||
set -euo pipefail
|
||||
printf 'runner arch: %s\n' "$(uname -m)"
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v ditto
|
||||
command -v hdiutil
|
||||
xcrun --find lipo
|
||||
xcrun --find vtool
|
||||
make release-assets-check
|
||||
|
||||
- name: Build macOS release assets
|
||||
run: make release-assets-macos
|
||||
|
||||
- name: Upload macOS release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-macos
|
||||
path: build/release/*
|
||||
if-no-files-found: error
|
||||
|
||||
draft-release:
|
||||
name: Create Draft Release
|
||||
needs: [prepare, build-web, build-linux, build-windows, build-macos]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Download release assets
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: release-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
|
||||
- name: Generate checksums
|
||||
run: make release-assets-checksums RELEASE_DIR=dist
|
||||
|
||||
- name: Publish draft release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
name: v${{ needs.prepare.outputs.version }}
|
||||
tag_name: ${{ github.ref_name }}
|
||||
draft: true
|
||||
files: |
|
||||
dist/*
|
||||
116
.github/workflows/test.yml
vendored
Normal file
116
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
name: Test (Full)
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
full:
|
||||
description: "Run full test suite including MySQL and E2E"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check (${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Install Linux system dependencies
|
||||
if: runner.os == 'Linux'
|
||||
run: sudo apt-get update && sudo apt-get install -y libayatana-appindicator3-dev
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Lint
|
||||
run: make lint
|
||||
|
||||
- name: Test
|
||||
run: make test
|
||||
|
||||
mysql:
|
||||
name: MySQL Tests
|
||||
if: inputs.full
|
||||
needs: check
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
env:
|
||||
MYSQL_ROOT_PASSWORD: testpass
|
||||
MYSQL_DATABASE: nex_test
|
||||
MYSQL_USER: nex_test
|
||||
MYSQL_PASSWORD: testpass
|
||||
ports:
|
||||
- 13306:3306
|
||||
options: >-
|
||||
--health-cmd="mysqladmin ping -h localhost -u root -ptestpass"
|
||||
--health-interval=3s
|
||||
--health-timeout=5s
|
||||
--health-retries=10
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: MySQL tests
|
||||
run: cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
e2e:
|
||||
name: E2E Tests
|
||||
if: inputs.full
|
||||
needs: check
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: cd frontend && bunx playwright install --with-deps chromium
|
||||
|
||||
- name: E2E tests
|
||||
run: cd frontend && bun run test:e2e
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -399,9 +399,21 @@ env/
|
||||
cython_debug/
|
||||
|
||||
# Custom
|
||||
.claude
|
||||
.claude/*
|
||||
!.claude/settings.json
|
||||
.opencode
|
||||
.codex
|
||||
openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
skills-lock.json
|
||||
skills-lock.json
|
||||
.worktrees
|
||||
!scripts/build/
|
||||
backend/bin
|
||||
backend/server
|
||||
backend/desktop
|
||||
|
||||
# Embedfs generated
|
||||
embedfs/assets/
|
||||
embedfs/frontend-dist/
|
||||
backend/cmd/desktop/rsrc_windows_*.syso
|
||||
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"files.eol": "\n"
|
||||
}
|
||||
184
LICENSE
Normal file
184
LICENSE
Normal file
@@ -0,0 +1,184 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and
|
||||
distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright
|
||||
owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities
|
||||
that control, are controlled by, or are under common control with that entity.
|
||||
For the purposes of this definition, "control" means (i) the power, direct or
|
||||
indirect, to cause the direction or management of such entity, whether by
|
||||
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
||||
permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including
|
||||
but not limited to software source code, documentation source, and configuration
|
||||
files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or
|
||||
translation of a Source form, including but not limited to compiled object code,
|
||||
generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form,
|
||||
made available under the License, as indicated by a copyright notice that is
|
||||
included in or attached to the work (an example is provided in the Appendix
|
||||
below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that
|
||||
is based on (or derived from) the Work and for which the editorial revisions,
|
||||
annotations, elaborations, or other modifications represent, as a whole, an
|
||||
original work of authorship. For the purposes of this License, Derivative Works
|
||||
shall not include works that remain separable from, or merely link (or bind by
|
||||
name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version
|
||||
of the Work and any modifications or additions to that Work or Derivative Works
|
||||
thereof, that is intentionally submitted to Licensor for inclusion in the Work
|
||||
by the copyright owner or by an individual or Legal Entity authorized to submit
|
||||
on behalf of the copyright owner. For the purposes of this definition,
|
||||
"submitted" means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems, and
|
||||
issue tracking systems that are managed by, or on behalf of, the Licensor for
|
||||
the purpose of discussing and improving the Work, but excluding communication
|
||||
that is conspicuously marked or otherwise designated in writing by the copyright
|
||||
owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
|
||||
of whom a Contribution has been received by Licensor and subsequently
|
||||
incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this
|
||||
License, each Contributor hereby grants to You a perpetual, worldwide,
|
||||
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
|
||||
reproduce, prepare Derivative Works of, publicly display, publicly perform,
|
||||
sublicense, and distribute the Work and such Derivative Works in Source or
|
||||
Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License,
|
||||
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section) patent
|
||||
license to make, have made, use, offer to sell, sell, import, and otherwise
|
||||
transfer the Work, where such license applies only to those patent claims
|
||||
licensable by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s) with the Work
|
||||
to which such Contribution(s) was submitted. If You institute patent litigation
|
||||
against any entity (including a cross-claim or counterclaim in a lawsuit)
|
||||
alleging that the Work or a Contribution incorporated within the Work
|
||||
constitutes direct or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate as of the date
|
||||
such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or
|
||||
Derivative Works thereof in any medium, with or without modifications, and in
|
||||
Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of
|
||||
this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that
|
||||
You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You
|
||||
distribute, all copyright, patent, trademark, and attribution notices from the
|
||||
Source form of the Work, excluding those notices that do not pertain to any part
|
||||
of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
|
||||
any Derivative Works that You distribute must include a readable copy of the
|
||||
attribution notices contained within such NOTICE file, excluding those notices
|
||||
that do not pertain to any part of the Derivative Works, in at least one of the
|
||||
following places: within a NOTICE text file distributed as part of the
|
||||
Derivative Works; within the Source form or documentation, if provided along
|
||||
with the Derivative Works; or, within a display generated by the Derivative
|
||||
Works, if and wherever such third-party notices normally appear. The contents of
|
||||
the NOTICE file are for informational purposes only and do not modify the
|
||||
License. You may add Your own attribution notices within Derivative Works that
|
||||
You distribute, alongside or as an addendum to the NOTICE text from the Work,
|
||||
provided that such additional attribution notices cannot be construed as
|
||||
modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide
|
||||
additional or different license terms and conditions for use, reproduction, or
|
||||
distribution of Your modifications, or for any such Derivative Works as a whole,
|
||||
provided Your use, reproduction, and distribution of the Work otherwise complies
|
||||
with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any
|
||||
Contribution intentionally submitted for inclusion in the Work by You to the
|
||||
Licensor shall be under the terms and conditions of this License, without any
|
||||
additional terms or conditions. Notwithstanding the above, nothing herein shall
|
||||
supersede or modify the terms of any separate license agreement you may have
|
||||
executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names,
|
||||
trademarks, service marks, or product names of the Licensor, except as required
|
||||
for reasonable and customary use in describing the origin of the Work and
|
||||
reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
|
||||
writing, Licensor provides the Work (and each Contributor provides its
|
||||
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied, including, without limitation, any warranties
|
||||
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any risks
|
||||
associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in
|
||||
tort (including negligence), contract, or otherwise, unless required by
|
||||
applicable law (such as deliberate and grossly negligent acts) or agreed to in
|
||||
writing, shall any Contributor be liable to You for damages, including any
|
||||
direct, indirect, special, incidental, or consequential damages of any character
|
||||
arising as a result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill, work stoppage,
|
||||
computer failure or malfunction, or any and all other commercial damages or
|
||||
losses), even if such Contributor has been advised of the possibility of such
|
||||
damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or
|
||||
Derivative Works thereof, You may choose to offer, and charge a fee for,
|
||||
acceptance of support, warranty, indemnity, or other liability obligations
|
||||
and/or rights consistent with this License. However, in accepting such
|
||||
obligations, You may act only on Your own behalf and on Your sole
|
||||
responsibility, not on behalf of any other Contributor, and only if You agree to
|
||||
indemnify, defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason of your
|
||||
accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate
|
||||
notice, with the fields enclosed by brackets "[]" replaced with your own
|
||||
identifying information. (Don't include the brackets!) The text should be
|
||||
enclosed in the appropriate comment syntax for the file format. We also
|
||||
recommend that a file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier identification within
|
||||
third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
520
Makefile
Normal file
520
Makefile
Normal file
@@ -0,0 +1,520 @@
|
||||
.PHONY: \
|
||||
lint test clean hooks-install hooks-check hooks-test \
|
||||
version-sync version-check version-bump \
|
||||
server-run server-build server-lint server-test server-clean \
|
||||
desktop-build-mac desktop-build-win desktop-build-linux \
|
||||
desktop-lint desktop-test desktop-clean \
|
||||
release-assets-check release-assets-web release-assets-linux release-assets-windows release-assets-macos release-assets-checksums \
|
||||
release-assets-server-linux release-assets-server-windows release-assets-server-macos \
|
||||
release-assets-desktop-linux release-assets-desktop-windows release-assets-desktop-macos \
|
||||
_backend-lint _backend-test _backend-clean _backend-build \
|
||||
_versionctl-lint _versionctl-test \
|
||||
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
|
||||
_hooks-pre-commit _check-clean-worktree \
|
||||
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
|
||||
_server-run-backend _server-run-frontend \
|
||||
_check-linux-target-arch _check-windows-target-arch _ensure-appimagetool \
|
||||
_package-linux-tar _package-linux-appimage _package-linux-deb _package-linux-rpm \
|
||||
_package-macos-zip _package-macos-dmg
|
||||
|
||||
# Delay shell lookups until a target needs them, then cache the result for this make run.
|
||||
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
|
||||
|
||||
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
|
||||
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
TARGET_ARCH ?= $(call lazy_shell,_TARGET_ARCH,go env GOARCH)
|
||||
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
|
||||
RELEASE_DIR ?= build/release
|
||||
LINUX_DESKTOP_BINARY = build/nex-linux-$(TARGET_ARCH)
|
||||
WINDOWS_DESKTOP_BINARY = build/nex-win-$(TARGET_ARCH).exe
|
||||
WINDOWS_SERVER_BINARY = build/nex-server-windows-$(TARGET_ARCH).exe
|
||||
WINDRES ?= windres
|
||||
|
||||
ifeq ($(TARGET_ARCH),arm64)
|
||||
APPIMAGE_ARCH := aarch64
|
||||
DEB_ARCH := arm64
|
||||
RPM_ARCH := aarch64
|
||||
else
|
||||
APPIMAGE_ARCH := x86_64
|
||||
DEB_ARCH := amd64
|
||||
RPM_ARCH := x86_64
|
||||
endif
|
||||
|
||||
WINDOWS_WINDRES_FORMAT_BFD := pe-x86-64
|
||||
WINDOWS_WINDRES_FORMAT_LLVM := x86_64-w64-mingw32
|
||||
WINDOWS_RESOURCE := rsrc_windows_amd64.syso
|
||||
|
||||
APPIMAGETOOL_PATH := build/tools/appimagetool-$(APPIMAGE_ARCH).AppImage
|
||||
APPIMAGETOOL_URL ?= https://github.com/AppImage/AppImageKit/releases/download/continuous/appimagetool-$(APPIMAGE_ARCH).AppImage
|
||||
APPIMAGETOOL ?= $(APPIMAGETOOL_PATH)
|
||||
|
||||
# ============================================
|
||||
# 全局命令
|
||||
# ============================================
|
||||
|
||||
lint: _backend-lint _frontend-check _versionctl-lint
|
||||
@printf 'Lint complete\n'
|
||||
|
||||
test: _backend-test _frontend-test _desktop-test _versionctl-test
|
||||
@printf 'All tests passed\n'
|
||||
|
||||
clean: _backend-clean _frontend-clean _desktop-clean
|
||||
@printf 'Clean complete\n'
|
||||
|
||||
# ============================================
|
||||
# Git hooks
|
||||
# ============================================
|
||||
|
||||
hooks-install:
|
||||
@hooks_dir=$$(git rev-parse --git-path hooks); \
|
||||
mkdir -p "$$hooks_dir"; \
|
||||
for hook in pre-commit commit-msg prepare-commit-msg; do \
|
||||
src="scripts/git-hooks/$$hook"; \
|
||||
if [ ! -f "$$src" ]; then \
|
||||
printf 'ERROR: source hook not found: %s\n' "$$src" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
cp "$$src" "$$hooks_dir/$$hook"; \
|
||||
chmod +x "$$hooks_dir/$$hook"; \
|
||||
done; \
|
||||
printf 'Installed Git hooks to %s\n' "$$hooks_dir"
|
||||
|
||||
hooks-check:
|
||||
@hooks_dir=$$(git rev-parse --git-path hooks); \
|
||||
status=0; \
|
||||
for hook in pre-commit commit-msg prepare-commit-msg; do \
|
||||
if [ -x "$$hooks_dir/$$hook" ]; then \
|
||||
printf 'OK: %s\n' "$$hook"; \
|
||||
else \
|
||||
printf 'MISSING: %s (%s/%s)\n' "$$hook" "$$hooks_dir" "$$hook"; \
|
||||
status=1; \
|
||||
fi; \
|
||||
done; \
|
||||
exit $$status
|
||||
|
||||
hooks-test:
|
||||
@scripts/git-hooks/test-hooks.sh
|
||||
|
||||
_hooks-pre-commit:
|
||||
@set -ef; \
|
||||
staged_files=$$(git diff --cached --name-only --diff-filter=ACM); \
|
||||
if [ -z "$$staged_files" ]; then \
|
||||
printf 'No staged files to check\n'; \
|
||||
exit 0; \
|
||||
fi; \
|
||||
run_backend_lint=; \
|
||||
run_versionctl_lint=; \
|
||||
run_frontend_check=; \
|
||||
lfs_patterns=$$(grep 'filter=lfs' .gitattributes 2>/dev/null | awk '{print $$1}' || true); \
|
||||
for file in $$staged_files; do \
|
||||
[ -n "$$file" ] || continue; \
|
||||
if git show ":$$file" 2>/dev/null | grep -Eq '^(<<<<<<<|=======|>>>>>>>)'; then \
|
||||
printf 'Found conflict markers in staged file: %s\n' "$$file" >&2; \
|
||||
printf 'Resolve conflict markers before committing.\n' >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
size=$$(git cat-file -s ":$$file" 2>/dev/null || printf '0'); \
|
||||
if [ "$$size" -gt 512000 ] 2>/dev/null; then \
|
||||
if git show ":$$file" 2>/dev/null | LC_ALL=C grep -Iq .; then \
|
||||
printf 'Warning: large staged text file (%s bytes): %s\n' "$$size" "$$file" >&2; \
|
||||
fi; \
|
||||
fi; \
|
||||
if [ -n "$$lfs_patterns" ]; then \
|
||||
for lfs_pat in $$lfs_patterns; do \
|
||||
case "$$file" in $$lfs_pat) \
|
||||
content=$$(git show ":$$file" 2>/dev/null | head -1); \
|
||||
case "$$content" in \
|
||||
"version https://git-lfs.github.com/spec/v1"*) ;; \
|
||||
*) \
|
||||
printf 'LFS-tracked file not using LFS pointer: %s\n' "$$file" >&2; \
|
||||
printf 'Run "git lfs install" and re-add this file.\n' >&2; \
|
||||
exit 1; \
|
||||
;; \
|
||||
esac; \
|
||||
break; \
|
||||
;; \
|
||||
esac; \
|
||||
done; \
|
||||
fi; \
|
||||
case "$$file" in \
|
||||
backend/*.go) run_backend_lint=1 ;; \
|
||||
versionctl/*.go) run_versionctl_lint=1 ;; \
|
||||
frontend/*.ts|frontend/*.tsx|frontend/*.scss) run_frontend_check=1 ;; \
|
||||
esac; \
|
||||
done; \
|
||||
if [ -n "$$run_backend_lint" ]; then \
|
||||
printf 'Running backend lint...\n'; \
|
||||
$(MAKE) _backend-lint; \
|
||||
fi; \
|
||||
if [ -n "$$run_versionctl_lint" ]; then \
|
||||
printf 'Running versionctl lint...\n'; \
|
||||
$(MAKE) _versionctl-lint; \
|
||||
fi; \
|
||||
if [ -n "$$run_frontend_check" ]; then \
|
||||
printf 'Running frontend check...\n'; \
|
||||
$(MAKE) _frontend-check; \
|
||||
fi; \
|
||||
printf 'Pre-commit checks passed\n'
|
||||
|
||||
# ============================================
|
||||
# 版本管理
|
||||
# ============================================
|
||||
|
||||
version-sync:
|
||||
go run ./versionctl sync
|
||||
|
||||
version-check:
|
||||
go run ./versionctl check
|
||||
|
||||
version-bump: BUMP ?= patch
|
||||
version-bump: lint test _check-clean-worktree
|
||||
@set -e; \
|
||||
bump_arg="$(if $(SET_VERSION),$(SET_VERSION),$(BUMP))"; \
|
||||
new_version=$$(go run ./versionctl bump "$$bump_arg"); \
|
||||
git add VERSION frontend/; \
|
||||
git commit -m "chore: 版本升迁 v$$new_version"; \
|
||||
git tag "v$$new_version"; \
|
||||
printf '版本升迁完成: v%s\n' "$$new_version"
|
||||
|
||||
_check-clean-worktree:
|
||||
@if [ -n "$$(git status --porcelain)" ]; then \
|
||||
printf '工作区不干净,请先提交或清理改动后再执行版本升迁。\n' >&2; \
|
||||
git status --short; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
# ============================================
|
||||
# Server 模式
|
||||
# ============================================
|
||||
|
||||
server-run:
|
||||
@$(MAKE) -j2 _server-run-backend _server-run-frontend
|
||||
|
||||
server-build: version-check _backend-build _frontend-build
|
||||
@printf 'Server build complete\n'
|
||||
|
||||
server-lint: _backend-lint _frontend-check
|
||||
@printf 'Server lint complete\n'
|
||||
|
||||
server-test: _backend-test _frontend-test
|
||||
@printf 'Server tests passed\n'
|
||||
|
||||
server-clean: _backend-clean _frontend-clean
|
||||
@printf 'Server artifacts cleaned\n'
|
||||
|
||||
_server-run-backend:
|
||||
@$(MAKE) -C backend run
|
||||
|
||||
_server-run-frontend: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
# ============================================
|
||||
# Desktop 模式
|
||||
# ============================================
|
||||
|
||||
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||
@printf 'Building macOS desktop...\n'
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
|
||||
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
|
||||
lipo -info build/nex-mac-universal | grep -q 'x86_64 arm64'
|
||||
rm -f build/nex-mac-arm64 build/nex-mac-amd64
|
||||
@printf 'Packaging macOS app bundle...\n'
|
||||
rm -rf build/Nex.app
|
||||
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
|
||||
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
|
||||
@if [ -f assets/icon.icns ]; then \
|
||||
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
|
||||
else \
|
||||
printf 'Missing assets/icon.icns\n'; \
|
||||
exit 1; \
|
||||
fi
|
||||
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
|
||||
if [ -z "$$MIN_MACOS_VERSION" ]; then \
|
||||
printf 'Unable to read macOS minimum version\n'; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
|
||||
chmod +x build/Nex.app/Contents/MacOS/nex
|
||||
@printf 'macOS desktop build complete\n'
|
||||
|
||||
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource _check-windows-target-arch
|
||||
@printf 'Building Windows desktop $(TARGET_ARCH)...\n'
|
||||
mkdir -p build
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../$(WINDOWS_DESKTOP_BINARY) ./cmd/desktop
|
||||
@printf 'Windows desktop build complete\n'
|
||||
|
||||
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _check-linux-target-arch
|
||||
@printf 'Building Linux desktop $(TARGET_ARCH)...\n'
|
||||
mkdir -p build
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../$(LINUX_DESKTOP_BINARY) ./cmd/desktop
|
||||
@printf 'Linux desktop build complete\n'
|
||||
|
||||
desktop-lint: _backend-lint _frontend-check
|
||||
@printf 'Desktop lint complete\n'
|
||||
|
||||
desktop-test: _desktop-test
|
||||
@printf 'Desktop tests passed\n'
|
||||
|
||||
desktop-clean: _desktop-clean
|
||||
@printf 'Desktop artifacts cleaned\n'
|
||||
|
||||
_desktop-test:
|
||||
cd backend && go test ./cmd/desktop/... -v
|
||||
|
||||
_desktop-clean:
|
||||
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
|
||||
|
||||
_desktop-prepare-frontend: _frontend-install
|
||||
@printf 'Preparing frontend for desktop...\n'
|
||||
cd frontend && cp .env.desktop .env.production.local
|
||||
cd frontend && bun run build
|
||||
rm -f frontend/.env.production.local
|
||||
|
||||
_desktop-prepare-embedfs:
|
||||
@printf 'Preparing embedded filesystem...\n'
|
||||
rm -rf embedfs/assets embedfs/frontend-dist
|
||||
cp -r assets embedfs/assets
|
||||
cp -r frontend/dist embedfs/frontend-dist
|
||||
|
||||
_desktop-prepare-windows-resource: _check-windows-target-arch
|
||||
@printf 'Preparing Windows $(TARGET_ARCH) executable icon...\n'
|
||||
@WINDRES_CMD="$(WINDRES)"; \
|
||||
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_BFD)"; \
|
||||
if command -v llvm-windres >/dev/null 2>&1; then \
|
||||
WINDRES_CMD=llvm-windres; \
|
||||
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_LLVM)"; \
|
||||
elif "$$WINDRES_CMD" --version 2>&1 | grep -qi LLVM; then \
|
||||
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_LLVM)"; \
|
||||
fi; \
|
||||
command -v "$$WINDRES_CMD" >/dev/null 2>&1 || { printf 'Missing windres tool: %s\n' "$$WINDRES_CMD"; exit 1; }; \
|
||||
cd backend/cmd/desktop && "$$WINDRES_CMD" -O coff -F "$$WINDRES_FMT" -i icon_windows.rc -o $(WINDOWS_RESOURCE)
|
||||
|
||||
# ============================================
|
||||
# 发布资产
|
||||
# ============================================
|
||||
|
||||
release-assets-check:
|
||||
go run ./versionctl release-assets-check
|
||||
@printf 'Release assets check passed\n'
|
||||
|
||||
release-assets-web: version-check release-assets-check _frontend-build
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
asset=$$(go run ./versionctl asset-name web tar.gz); \
|
||||
tar -C frontend -czf "$(RELEASE_DIR)/$$asset" dist
|
||||
|
||||
release-assets-linux: version-check release-assets-check _check-linux-target-arch
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-server-linux TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-desktop-linux TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||
|
||||
release-assets-windows: version-check release-assets-check _check-windows-target-arch
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-server-windows TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-desktop-windows TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||
|
||||
release-assets-macos: version-check release-assets-check
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-server-macos RELEASE_DIR="$(RELEASE_DIR)"
|
||||
@$(MAKE) release-assets-desktop-macos RELEASE_DIR="$(RELEASE_DIR)"
|
||||
|
||||
release-assets-server-linux: version-check _check-linux-target-arch
|
||||
mkdir -p build "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-$(TARGET_ARCH) ./cmd/server
|
||||
asset=$$(go run ./versionctl asset-name server linux $(TARGET_ARCH) tar.gz); \
|
||||
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-linux-$(TARGET_ARCH)
|
||||
|
||||
release-assets-server-windows: version-check _check-windows-target-arch
|
||||
mkdir -p build "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../$(WINDOWS_SERVER_BINARY) ./cmd/server
|
||||
asset=$$(go run ./versionctl asset-name server windows $(TARGET_ARCH) zip); \
|
||||
if command -v powershell.exe >/dev/null 2>&1; then POWERSHELL=powershell.exe; else POWERSHELL=powershell; fi; \
|
||||
"$$POWERSHELL" -NoProfile -Command "Compress-Archive -LiteralPath '$(WINDOWS_SERVER_BINARY)' -DestinationPath '$(RELEASE_DIR)/$$asset' -Force"
|
||||
|
||||
release-assets-server-macos: version-check
|
||||
mkdir -p build "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-macos-amd64 ./cmd/server
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-macos-arm64 ./cmd/server
|
||||
lipo -create build/nex-server-macos-amd64 build/nex-server-macos-arm64 -output build/nex-server-macos-universal
|
||||
lipo -info build/nex-server-macos-universal | grep -q 'x86_64 arm64'
|
||||
asset=$$(go run ./versionctl asset-name server macos amd64 tar.gz); \
|
||||
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-amd64
|
||||
asset=$$(go run ./versionctl asset-name server macos arm64 tar.gz); \
|
||||
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-arm64
|
||||
asset=$$(go run ./versionctl asset-name server macos universal tar.gz); \
|
||||
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-universal
|
||||
rm -f build/nex-server-macos-amd64 build/nex-server-macos-arm64 build/nex-server-macos-universal
|
||||
|
||||
release-assets-desktop-linux: version-check release-assets-check desktop-build-linux _package-linux-tar _package-linux-appimage _package-linux-deb _package-linux-rpm
|
||||
|
||||
release-assets-desktop-windows: version-check release-assets-check desktop-build-win
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
asset=$$(go run ./versionctl asset-name desktop windows $(TARGET_ARCH) zip); \
|
||||
if command -v powershell.exe >/dev/null 2>&1; then POWERSHELL=powershell.exe; else POWERSHELL=powershell; fi; \
|
||||
"$$POWERSHELL" -NoProfile -Command "Compress-Archive -LiteralPath '$(WINDOWS_DESKTOP_BINARY)' -DestinationPath '$(RELEASE_DIR)/$$asset' -Force"
|
||||
|
||||
release-assets-desktop-macos: version-check release-assets-check desktop-build-mac _package-macos-zip _package-macos-dmg
|
||||
rm -rf build/Nex.app build/dmg
|
||||
|
||||
release-assets-checksums:
|
||||
@cd "$(RELEASE_DIR)" && \
|
||||
rm -f SHA256SUMS && \
|
||||
for asset in *; do \
|
||||
[ -f "$$asset" ] || continue; \
|
||||
if command -v sha256sum >/dev/null 2>&1; then \
|
||||
sha256sum "$$asset"; \
|
||||
elif command -v shasum >/dev/null 2>&1; then \
|
||||
shasum -a 256 "$$asset"; \
|
||||
else \
|
||||
printf 'Missing sha256sum or shasum\n' >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
done > SHA256SUMS && \
|
||||
test -s SHA256SUMS
|
||||
|
||||
_check-linux-target-arch:
|
||||
@if [ "$(TARGET_ARCH)" != "amd64" ] && [ "$(TARGET_ARCH)" != "arm64" ]; then \
|
||||
printf 'Unsupported Linux TARGET_ARCH: %s\n' "$(TARGET_ARCH)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
_check-windows-target-arch:
|
||||
@if [ "$(TARGET_ARCH)" != "amd64" ]; then \
|
||||
printf 'Unsupported Windows TARGET_ARCH: %s\n' "$(TARGET_ARCH)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
_ensure-appimagetool:
|
||||
@mkdir -p build/tools
|
||||
@if [ ! -x "$(APPIMAGETOOL)" ]; then \
|
||||
printf 'Downloading appimagetool for %s...\n' "$(APPIMAGE_ARCH)"; \
|
||||
command -v curl >/dev/null 2>&1 || { printf 'Missing curl for appimagetool download\n'; exit 1; }; \
|
||||
curl -L "$(APPIMAGETOOL_URL)" -o "$(APPIMAGETOOL)"; \
|
||||
chmod +x "$(APPIMAGETOOL)"; \
|
||||
fi; \
|
||||
printf 'Using appimagetool: %s\n' "$(APPIMAGETOOL)"
|
||||
|
||||
_package-linux-tar:
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) tar.gz); \
|
||||
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-linux-$(TARGET_ARCH)
|
||||
|
||||
_package-linux-appimage: _ensure-appimagetool
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
appdir="build/appimage/nex-$(TARGET_ARCH).AppDir"; \
|
||||
rm -rf "$$appdir"; \
|
||||
mkdir -p "$$appdir/usr/bin" "$$appdir/usr/share/applications" "$$appdir/usr/share/icons"; \
|
||||
install -m 0755 "$(LINUX_DESKTOP_BINARY)" "$$appdir/usr/bin/nex"; \
|
||||
install -m 0644 packaging/linux/nex.desktop "$$appdir/nex.desktop"; \
|
||||
install -m 0644 packaging/linux/nex.desktop "$$appdir/usr/share/applications/nex.desktop"; \
|
||||
install -m 0755 packaging/linux/AppRun "$$appdir/AppRun"; \
|
||||
cp -R assets/icons/hicolor "$$appdir/usr/share/icons/"; \
|
||||
cp assets/icon.png "$$appdir/nex.png"; \
|
||||
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) AppImage); \
|
||||
ARCH=$(APPIMAGE_ARCH) APPIMAGE_EXTRACT_AND_RUN=1 "$(APPIMAGETOOL)" "$$appdir" "$(RELEASE_DIR)/$$asset"; \
|
||||
chmod +x "$(RELEASE_DIR)/$$asset"; \
|
||||
test -s "$(RELEASE_DIR)/$$asset"
|
||||
|
||||
_package-linux-deb:
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
pkgdir="build/pkg/deb/nex-$(TARGET_ARCH)"; \
|
||||
rm -rf "$$pkgdir"; \
|
||||
mkdir -p "$$pkgdir/DEBIAN" "$$pkgdir/usr/bin" "$$pkgdir/usr/share/applications" "$$pkgdir/usr/share/icons"; \
|
||||
install -m 0755 "$(LINUX_DESKTOP_BINARY)" "$$pkgdir/usr/bin/nex"; \
|
||||
install -m 0644 packaging/linux/nex.desktop "$$pkgdir/usr/share/applications/nex.desktop"; \
|
||||
cp -R assets/icons/hicolor "$$pkgdir/usr/share/icons/"; \
|
||||
printf '%s\n' \
|
||||
'Package: nex' \
|
||||
'Version: $(VERSION)' \
|
||||
'Section: utils' \
|
||||
'Priority: optional' \
|
||||
'Architecture: $(DEB_ARCH)' \
|
||||
'Maintainer: Nex Maintainers <noreply@example.com>' \
|
||||
'Depends: libgtk-3-0, libayatana-appindicator3-1, xdg-utils' \
|
||||
'Description: AI Gateway desktop application' \
|
||||
' Nex is an AI Gateway desktop application.' \
|
||||
> "$$pkgdir/DEBIAN/control"; \
|
||||
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) deb); \
|
||||
dpkg-deb --build --root-owner-group "$$pkgdir" "$(RELEASE_DIR)/$$asset"; \
|
||||
dpkg-deb -I "$(RELEASE_DIR)/$$asset" >/dev/null
|
||||
|
||||
_package-linux-rpm:
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
topdir="$(abspath build/rpmbuild-$(TARGET_ARCH))"; \
|
||||
rm -rf "$$topdir"; \
|
||||
mkdir -p "$$topdir/BUILD" "$$topdir/BUILDROOT" "$$topdir/RPMS" "$$topdir/SOURCES" "$$topdir/SPECS" "$$topdir/SRPMS"; \
|
||||
rpmbuild -bb --target "$(RPM_ARCH)" \
|
||||
--define "_topdir $$topdir" \
|
||||
--define "nex_version $(VERSION)" \
|
||||
--define "nex_binary $(abspath $(LINUX_DESKTOP_BINARY))" \
|
||||
--define "nex_desktop_file $(abspath packaging/linux/nex.desktop)" \
|
||||
--define "nex_icons_dir $(abspath assets/icons/hicolor)" \
|
||||
packaging/linux/nex.spec; \
|
||||
rpm_file=$$(find "$$topdir/RPMS" -type f -name '*.rpm' | sort | tail -n 1); \
|
||||
test -n "$$rpm_file"; \
|
||||
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) rpm); \
|
||||
cp "$$rpm_file" "$(RELEASE_DIR)/$$asset"; \
|
||||
rpm -qip "$(RELEASE_DIR)/$$asset" >/dev/null
|
||||
|
||||
_package-macos-zip:
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
asset=$$(go run ./versionctl asset-name desktop macos universal zip); \
|
||||
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$$asset"
|
||||
|
||||
_package-macos-dmg:
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
dmgdir="build/dmg/Nex"; \
|
||||
rm -rf "$$dmgdir"; \
|
||||
mkdir -p "$$dmgdir"; \
|
||||
cp -R build/Nex.app "$$dmgdir/Nex.app"; \
|
||||
ln -s /Applications "$$dmgdir/Applications"; \
|
||||
asset=$$(go run ./versionctl asset-name desktop macos universal dmg); \
|
||||
hdiutil create -volname Nex -srcfolder "$$dmgdir" -ov -format UDZO "$(RELEASE_DIR)/$$asset"; \
|
||||
hdiutil verify "$(RELEASE_DIR)/$$asset" >/dev/null && \
|
||||
rm -rf "$$dmgdir"
|
||||
|
||||
# ============================================
|
||||
# 共享 helper targets
|
||||
# ============================================
|
||||
|
||||
_backend-build:
|
||||
@$(MAKE) -C backend build
|
||||
|
||||
_backend-lint:
|
||||
@$(MAKE) -C backend lint
|
||||
|
||||
_backend-test:
|
||||
@$(MAKE) -C backend test
|
||||
|
||||
_backend-clean:
|
||||
@$(MAKE) -C backend clean
|
||||
|
||||
_versionctl-lint:
|
||||
@$(MAKE) -C versionctl lint
|
||||
|
||||
_versionctl-test:
|
||||
@$(MAKE) -C versionctl test
|
||||
|
||||
_frontend-install:
|
||||
cd frontend && bun install
|
||||
|
||||
_frontend-build: _frontend-install
|
||||
cd frontend && bun run build
|
||||
|
||||
_frontend-check: _frontend-install
|
||||
cd frontend && bun run check
|
||||
|
||||
_frontend-test: _frontend-install
|
||||
cd frontend && bun run test
|
||||
|
||||
_frontend-dev: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
_frontend-clean:
|
||||
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
|
||||
363
README.md
363
README.md
@@ -7,13 +7,15 @@
|
||||
```
|
||||
nex/
|
||||
├── backend/ # Go 后端服务(分层架构)
|
||||
│ ├── cmd/server/ # 主程序入口
|
||||
│ ├── cmd/
|
||||
│ │ ├── server/ # CLI 主程序入口
|
||||
│ │ └── desktop/ # 桌面应用入口
|
||||
│ ├── internal/
|
||||
│ │ ├── handler/ # HTTP 处理器 + 中间件
|
||||
│ │ ├── service/ # 业务逻辑层
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ ├── domain/ # 领域模型
|
||||
│ │ ├── protocol/ # 协议适配器(OpenAI/Anthropic)
|
||||
│ │ ├── conversion/ # 协议转换引擎(OpenAI/Anthropic 适配器)
|
||||
│ │ ├── provider/ # 供应商客户端
|
||||
│ │ └── config/ # 配置管理
|
||||
│ ├── pkg/ # 公共包(logger/errors/validator)
|
||||
@@ -25,25 +27,37 @@ nex/
|
||||
│ │ ├── api/ # API 层(统一请求封装 + 字段转换)
|
||||
│ │ ├── hooks/ # TanStack Query hooks
|
||||
│ │ ├── components/ # 通用组件(AppLayout)
|
||||
│ │ ├── pages/ # 页面(Providers, Stats)
|
||||
│ │ ├── pages/ # 页面(Providers, Stats, Settings)
|
||||
│ │ ├── routes/ # React Router 路由配置
|
||||
│ │ ├── types/ # TypeScript 类型定义
|
||||
│ │ └── __tests__/ # 单元测试 + 组件测试
|
||||
│ ├── e2e/ # Playwright E2E 测试
|
||||
│ └── package.json
|
||||
│
|
||||
├── assets/ # 应用资源
|
||||
│ ├── icon.png # 托盘图标
|
||||
│ ├── icon.icns # macOS 应用图标
|
||||
│ └── icon.ico # Windows 应用图标
|
||||
│
|
||||
├── packaging/ # 桌面发布包元数据(Linux desktop entry、RPM spec 等)
|
||||
│
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||
- **透明代理**:对 OpenAI 兼容供应商透传请求
|
||||
- **流式响应**:完整支持 SSE 流式传输
|
||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
|
||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||
- **Function Calling**:支持工具调用(Tools)
|
||||
- **多供应商管理**:配置和管理多个供应商
|
||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||
- **扩展接口**:支持 Embeddings 和 Rerank 接口
|
||||
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||
- **Web 配置界面**:提供供应商和模型配置管理
|
||||
- **启动参数设置**:通过 Web 界面查看和编辑启动参数(Desktop 可编辑、Server 只读)
|
||||
|
||||
## 技术栈
|
||||
|
||||
@@ -51,12 +65,26 @@ nex/
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转)
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转 + 模块标识)
|
||||
- **配置**: Viper + pflag(Server 多层配置,Desktop 配置文件快照)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
#### 日志模块标识规范
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger,日志输出格式为 `[module.name]`:
|
||||
|
||||
```
|
||||
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
|
||||
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
|
||||
```
|
||||
|
||||
模块命名规范:
|
||||
- 单一职责包:`database`、`config`
|
||||
- 多实体包:`handler.proxy`、`service.provider`
|
||||
- 子包:`handler.middleware`
|
||||
|
||||
### 前端
|
||||
- **运行时**: Bun
|
||||
- **构建工具**: Vite
|
||||
@@ -66,56 +94,144 @@ nex/
|
||||
- **图表库**: Recharts
|
||||
- **路由**: React Router v7
|
||||
- **数据获取**: TanStack Query v5
|
||||
- **样式**: SCSS Modules
|
||||
- **样式**: TDesign 组件 props 优先,TDesign tokens 次之,SCSS 作为兜底补充
|
||||
- **测试**: Vitest + React Testing Library + Playwright
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 后端
|
||||
### 桌面应用(推荐)
|
||||
|
||||
**构建桌面应用**:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
go mod download
|
||||
go run cmd/server/main.go
|
||||
# macOS (arm64 + amd64,并打包为 .app)
|
||||
make desktop-build-mac
|
||||
|
||||
# Windows
|
||||
make desktop-build-win
|
||||
|
||||
# Linux
|
||||
make desktop-build-linux
|
||||
|
||||
# Linux arm64
|
||||
make desktop-build-linux TARGET_ARCH=arm64
|
||||
```
|
||||
|
||||
后端服务将在 `http://localhost:9826` 启动。首次启动会自动:
|
||||
- 创建配置文件 `~/.nex/config.yaml`
|
||||
**使用桌面应用**:
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64 / nex-linux-arm64)
|
||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||
|
||||
**注意事项**:
|
||||
- 桌面应用需要 CGO 支持
|
||||
- macOS: 自带 Xcode Command Line Tools
|
||||
- Linux 构建: 需要 gcc、pkg-config、GTK3 开发包和 Ayatana AppIndicator 开发包(Ubuntu/Debian: `libgtk-3-dev`、`libayatana-appindicator3-dev`)
|
||||
- Linux 运行: 需要 GTK3、Ayatana AppIndicator 和 xdg-utils;AppImage 也依赖系统提供 AppImage runtime/FUSE 能力,不承诺完全自包含
|
||||
- Windows: 需要对应架构的 MinGW-w64/MSYS2 工具链,desktop 使用 GUI linker flags 隐藏控制台窗口
|
||||
- macOS DMG: 发布包暂不签名、不 notarize,首次打开可能出现 Gatekeeper 提示
|
||||
|
||||
**Linux 桌面环境兼容性**:
|
||||
- GNOME: 需要 AppIndicator 扩展
|
||||
- KDE Plasma: 原生支持
|
||||
- Xfce: 需要 libappindicator
|
||||
- 其他支持 StatusNotifierItem 规范的环境
|
||||
|
||||
### Server 模式(前后端分离)
|
||||
|
||||
```bash
|
||||
make server-run
|
||||
```
|
||||
|
||||
`make server-run` 会并行启动:
|
||||
- 后端服务:`http://localhost:9826`
|
||||
- 前端开发服务器:`http://localhost:5173`
|
||||
|
||||
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
|
||||
- 初始化数据库 `~/.nex/config.db`
|
||||
- 运行数据库迁移
|
||||
- 创建日志目录 `~/.nex/log/`
|
||||
|
||||
### 前端
|
||||
**构建 server 模式产物**:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun install
|
||||
bun dev
|
||||
make server-build
|
||||
```
|
||||
|
||||
前端开发服务器将在 `http://localhost:5173` 启动,API 请求通过 Vite proxy 转发到后端。
|
||||
### Release 产物
|
||||
|
||||
发布流程由 Git tag `vX.Y.Z` 触发,GitHub Actions 会先通过全流程测试门禁,再构建并创建 Draft Release,上传 server、web 和 desktop 三类产物,同时生成 `SHA256SUMS`。
|
||||
|
||||
**server 产物**(不内置 Web 管理界面):
|
||||
|
||||
| 平台 | 产物 |
|
||||
|------|------|
|
||||
| Linux amd64 | `nex-server_<version>_linux_amd64.tar.gz` |
|
||||
| Linux arm64 | `nex-server_<version>_linux_arm64.tar.gz` |
|
||||
| macOS amd64 | `nex-server_<version>_macos_amd64.tar.gz` |
|
||||
| macOS arm64 | `nex-server_<version>_macos_arm64.tar.gz` |
|
||||
| macOS universal | `nex-server_<version>_macos_universal.tar.gz` |
|
||||
| Windows amd64 | `nex-server_<version>_windows_amd64.zip` |
|
||||
|
||||
**web 产物**:
|
||||
|
||||
| 内容 | 产物 |
|
||||
|------|------|
|
||||
| `frontend/dist` | `nex-web_<version>.tar.gz` |
|
||||
|
||||
**desktop 产物**:
|
||||
|
||||
| 平台 | 产物 |
|
||||
|------|------|
|
||||
| Linux amd64 | `nex-desktop_<version>_linux_amd64.tar.gz`、`.AppImage`、`.deb`、`.rpm` |
|
||||
| Linux arm64 | `nex-desktop_<version>_linux_arm64.tar.gz`、`.AppImage`、`.deb`、`.rpm` |
|
||||
| macOS universal | `nex-desktop_<version>_macos_universal.zip`、`nex-desktop_<version>_macos_universal.dmg` |
|
||||
| Windows amd64 | `nex-desktop_<version>_windows_amd64.zip` |
|
||||
|
||||
Linux deb 包声明 `libgtk-3-0`、`libayatana-appindicator3-1`、`xdg-utils` 运行依赖;rpm 包声明 `gtk3`、`libayatana-appindicator-gtk3`、`xdg-utils` 运行依赖。Rocky Linux 9 等发行版可能需要启用 EPEL 才能解析 Ayatana AppIndicator 依赖。
|
||||
|
||||
server 和 desktop 发布产物自包含运行时数据库迁移资源(通过 `go:embed` 嵌入二进制),安装后首次启动不再依赖仓库源码目录。
|
||||
|
||||
## API 接口
|
||||
|
||||
### 代理接口(对外部应用)
|
||||
|
||||
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
||||
- `POST /v1/messages` - Anthropic Messages API
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写并保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
|
||||
|
||||
**OpenAI 协议**(`protocol=openai`):
|
||||
- `POST /openai/v1/chat/completions` - 对话补全
|
||||
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/v1/embeddings` - 嵌入
|
||||
- `POST /openai/v1/rerank` - 重排序
|
||||
|
||||
**Anthropic 协议**(`protocol=anthropic`):
|
||||
- `POST /anthropic/v1/messages` - 消息对话
|
||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
|
||||
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions`、`/v1/models`、`/v1/embeddings`、`/v1/rerank`,并在构建上游 URL 时去掉 `/v1`;Anthropic adapter 接收 `/v1/messages`、`/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`),Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
|
||||
|
||||
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON`、`MODEL_NOT_FOUND`、`CONVERSION_FAILED`、`UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
|
||||
|
||||
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
|
||||
|
||||
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
|
||||
|
||||
### 管理接口(对前端)
|
||||
|
||||
#### 供应商管理
|
||||
- `GET /api/providers` - 列出所有供应商
|
||||
- `POST /api/providers` - 创建供应商
|
||||
- `POST /api/providers` - 创建供应商(`id` 仅限字母、数字、下划线,长度 1-64)
|
||||
- `GET /api/providers/:id` - 获取供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商(`id` 不可修改)
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
#### 模型管理
|
||||
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
||||
- `POST /api/models` - 创建模型
|
||||
- `GET /api/models/:id` - 获取模型
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `POST /api/models` - 创建模型(`id` 由系统自动生成 UUID,`provider_id` + `model_name` 联合唯一)
|
||||
- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`)
|
||||
- `PUT /api/models/:id` - 更新模型(不可修改 `id`)
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
#### 统计查询
|
||||
@@ -124,9 +240,29 @@ bun dev
|
||||
|
||||
查询参数支持:`provider_id`、`model_name`、`start_date`、`end_date`、`group_by`
|
||||
|
||||
#### 启动参数设置
|
||||
- `GET /api/settings/startup` - 查询启动参数设置
|
||||
- `PUT /api/settings/startup` - 保存启动参数设置(仅 Desktop 模式)
|
||||
|
||||
**行为差异**:
|
||||
- **Desktop 模式**:查询返回配置文件编辑视图(`~/.nex/config.yaml` + 默认值),允许保存到配置文件,保存后当前运行服务不受影响,需重启 Desktop 生效
|
||||
- **Server 模式**:查询返回当前运行有效配置,保存请求始终返回 403
|
||||
|
||||
响应包含 `mode`、`editable`、`config_path`、`restart_required` 元数据和完整启动参数配置。Duration 字段使用字符串格式(如 `30s`、`1h`)。
|
||||
|
||||
#### 版本信息
|
||||
- `GET /api/version` - 获取后端构建版本信息(`version`、`commit`、`build_time`),用于前端 About 页面诊断前后端版本一致性
|
||||
|
||||
## 配置
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
|
||||
配置方式取决于启动模式:
|
||||
|
||||
- **Server 模式**(`cmd/server`):支持 CLI 参数 > 环境变量 > 配置文件 > 默认值
|
||||
- **Desktop 模式**(`cmd/desktop`):仅支持配置文件 `~/.nex/config.yaml` > 默认值,修改配置文件后需重启 desktop 生效
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`。配置文件不存在时使用默认值,不会自动生成;需要自定义时手动创建该文件:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
@@ -135,7 +271,14 @@ server:
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
@@ -149,50 +292,156 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
数据文件:
|
||||
### 环境变量(仅 Server 模式)
|
||||
|
||||
Server 模式下,所有配置项支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
**Desktop 模式不支持环境变量覆盖。**Desktop 仅从 `~/.nex/config.yaml` 和默认值读取配置。
|
||||
|
||||
### CLI 参数(仅 Server 模式)
|
||||
|
||||
Server 模式下,支持命令行参数:
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
||||
|
||||
**Desktop 不支持命令行参数覆盖配置。**Desktop 忽略所有 CLI 参数,仅从 `~/.nex/config.yaml` 读取。
|
||||
|
||||
### 数据文件
|
||||
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
## 测试
|
||||
|
||||
### 后端测试
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make test # 运行所有测试
|
||||
make test-coverage # 生成覆盖率报告
|
||||
# 全局默认测试(不含 MySQL 和前端 E2E)
|
||||
make test
|
||||
|
||||
# 产品级测试
|
||||
make server-test
|
||||
make desktop-test
|
||||
```
|
||||
|
||||
### 前端测试
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run test # 单元测试 + 组件测试
|
||||
bun run test:watch # 监听模式
|
||||
bun run test:coverage # 生成覆盖率报告
|
||||
bun run test:e2e # E2E 测试
|
||||
```
|
||||
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md` 与 `frontend/README.md`。
|
||||
|
||||
## 开发
|
||||
|
||||
### 后端开发
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make build # 构建
|
||||
make lint # 代码检查
|
||||
make migrate-up # 数据库迁移
|
||||
# 首次克隆后安装 Git hooks
|
||||
make hooks-install
|
||||
|
||||
# 检查 Git hooks 安装状态
|
||||
make hooks-check
|
||||
|
||||
# 运行 Git hooks 回归测试
|
||||
make hooks-test
|
||||
|
||||
# 全局命令
|
||||
make lint # 前后端共享检查
|
||||
make test # 默认全量测试(不含 MySQL/E2E)
|
||||
make clean # 清理所有构建产物和测试报告
|
||||
|
||||
# server 模式
|
||||
make server-run # 并行启动后端和前端开发服务
|
||||
make server-build # 构建 backend/bin/server 和 frontend/dist
|
||||
make server-lint # server 模式检查
|
||||
make server-test # server 模式测试
|
||||
make server-clean # 清理 server 模式产物
|
||||
|
||||
# desktop 模式
|
||||
make desktop-build-mac # 构建 macOS 桌面应用
|
||||
make desktop-build-win # 构建 Windows 桌面应用
|
||||
make desktop-build-linux # 构建 Linux 桌面应用
|
||||
make desktop-lint # desktop 模式检查
|
||||
make desktop-test # desktop 专属测试
|
||||
make desktop-clean # 清理 desktop 产物
|
||||
```
|
||||
|
||||
### 前端开发
|
||||
Git hooks 使用仓库内 `scripts/git-hooks/` 的原生脚本,不依赖额外 hook 框架。当前 hooks 包含:
|
||||
|
||||
- pre-commit:检查 staged files 的冲突标记、大文件告警和 LFS 指针,并按文件类型委托后端、versionctl、前端检查
|
||||
- prepare-commit-msg:在编辑器打开时提供提交信息模板,辅助填写 `类型: 简短描述` 和多行说明
|
||||
- commit-msg:校验提交信息格式为 `类型: 简短描述`,多行说明需在首行后保留空行;提交描述按项目规范使用中文,hook 不做字符集检测
|
||||
|
||||
## 版本与发布
|
||||
|
||||
### 统一版本源
|
||||
|
||||
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
|
||||
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
|
||||
|
||||
### 本地版本演进
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run build # 构建生产版本
|
||||
bun run lint # 代码检查
|
||||
# 递增版本(自动 lint + test + 工作区检查 + sync/check + commit + tag)
|
||||
make version-bump BUMP=minor
|
||||
|
||||
# 或指定具体版本号
|
||||
make version-bump SET_VERSION=1.0.0
|
||||
|
||||
# 推送到远程
|
||||
git push --follow-tags
|
||||
```
|
||||
|
||||
手动同步和校验:
|
||||
|
||||
```bash
|
||||
make version-sync
|
||||
make version-check
|
||||
```
|
||||
|
||||
### 本地生成发布资产
|
||||
|
||||
```bash
|
||||
# Linux: server + desktop
|
||||
make release-assets-linux
|
||||
|
||||
# Windows: server + desktop(需在 Windows 环境执行)
|
||||
make release-assets-windows
|
||||
|
||||
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
|
||||
make release-assets-macos
|
||||
```
|
||||
|
||||
生成的版本化发布资产位于 `build/release/`。
|
||||
|
||||
### GitHub Draft Release
|
||||
|
||||
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
|
||||
- 流水线会先校验 tag 与 `VERSION` 一致,再执行全流程测试门禁(lint、默认测试、MySQL 测试、E2E 测试),测试不通过则阻止构建
|
||||
- 测试通过后,三个平台 job 并行构建,各 job 会在正式构建前先检查 `go`、`bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
|
||||
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
|
||||
- 构建以下资产并上传到 GitHub Draft Release:
|
||||
- Linux server
|
||||
- Windows server
|
||||
- darwin-amd64 server
|
||||
- darwin-arm64 server
|
||||
- Linux desktop
|
||||
- Windows desktop
|
||||
- macOS desktop universal
|
||||
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
|
||||
|
||||
## 开发规范
|
||||
|
||||
详见各子项目的 README.md:
|
||||
@@ -201,4 +450,4 @@ bun run lint # 代码检查
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
Apache License 2.0
|
||||
|
||||
BIN
assets/icon.icns
LFS
Normal file
BIN
assets/icon.icns
LFS
Normal file
Binary file not shown.
BIN
assets/icon.ico
LFS
Normal file
BIN
assets/icon.ico
LFS
Normal file
Binary file not shown.
BIN
assets/icon.png
LFS
Normal file
BIN
assets/icon.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
Binary file not shown.
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- forbidigo
|
||||
- errorlint
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- bodyclose
|
||||
- noctx
|
||||
- nilerr
|
||||
- goimports
|
||||
- gocyclo
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-blank: true
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- p: '^fmt\.Print.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^fmt\.Fprint.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||
msg: 使用 zap logger,不要使用标准库 log
|
||||
- p: '^zap\.L$'
|
||||
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||
- p: '^zap\.S$'
|
||||
msg: 不使用 Sugar logger
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: error-strings
|
||||
- name: error-return
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: unexported-return
|
||||
goimports:
|
||||
local-prefixes: nex/backend
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- tests/mocks
|
||||
exclude-generated: true
|
||||
exclude-rules:
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- errcheck
|
||||
source: '(^\s*_\s*=|,\s*_)'
|
||||
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||
linters:
|
||||
- errcheck
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- revive
|
||||
text: '^exported:'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gosec
|
||||
text: 'G(101|401|501)'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gocyclo
|
||||
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||
- linters:
|
||||
- revive
|
||||
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||
- path: 'internal/conversion/.*\.go'
|
||||
linters:
|
||||
- gocyclo
|
||||
- gocritic
|
||||
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||
linters:
|
||||
- gocyclo
|
||||
@@ -1,45 +1,97 @@
|
||||
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
|
||||
.PHONY: \
|
||||
build run \
|
||||
test test-unit test-integration test-coverage \
|
||||
lint clean \
|
||||
migrate-up migrate-down migrate-status migrate-create \
|
||||
mysql-up mysql-down mysql-test mysql-test-quick
|
||||
|
||||
VERSION := $(shell go run ../versionctl print)
|
||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
|
||||
DB_DRIVER ?= sqlite3
|
||||
DB_DSN ?= $(HOME)/.nex/config.db
|
||||
|
||||
ifeq ($(DB_DRIVER),mysql)
|
||||
GOOSE_DIR := migrations/mysql
|
||||
GOOSE_DRIVER := mysql
|
||||
else ifeq ($(DB_DRIVER),sqlite3)
|
||||
GOOSE_DIR := migrations/sqlite
|
||||
GOOSE_DRIVER := sqlite3
|
||||
else
|
||||
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
|
||||
endif
|
||||
|
||||
# 构建
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
|
||||
|
||||
# 运行
|
||||
run:
|
||||
go run ./cmd/server
|
||||
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
|
||||
|
||||
# 测试
|
||||
test:
|
||||
go test ./... -v
|
||||
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||
|
||||
test-unit:
|
||||
go test ./internal/... ./pkg/... -v
|
||||
|
||||
test-integration:
|
||||
go test ./tests/... -v
|
||||
|
||||
# 测试覆盖率
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
@printf 'Coverage report generated: backend/coverage.html\n'
|
||||
|
||||
lint:
|
||||
go tool golangci-lint run ./...
|
||||
|
||||
# 清理
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
# 数据库迁移
|
||||
migrate-up:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
@printf 'Running database migration up...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
|
||||
|
||||
migrate-down:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
@printf 'Running database migration down...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
|
||||
|
||||
migrate-status:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
@printf 'Checking database migration status...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
|
||||
|
||||
migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
goose -dir migrations create $$name sql
|
||||
@printf 'Migration name: '; \
|
||||
read name; \
|
||||
goose -dir migrations/sqlite create $$name sql; \
|
||||
goose -dir migrations/mysql create $$name sql
|
||||
|
||||
# 代码检查
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
mysql-up:
|
||||
@printf 'Starting MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose up -d
|
||||
@printf 'Waiting for MySQL to be ready...\n'
|
||||
@for i in $$(seq 1 30); do \
|
||||
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
|
||||
printf 'MySQL is ready\n'; \
|
||||
exit 0; \
|
||||
fi; \
|
||||
printf 'Waiting... (%s/30)\n' $$i; \
|
||||
sleep 1; \
|
||||
done; \
|
||||
printf 'MySQL failed to start\n'; \
|
||||
exit 1
|
||||
|
||||
# 安装依赖
|
||||
deps:
|
||||
go mod tidy
|
||||
mysql-down:
|
||||
@printf 'Stopping MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose down -v
|
||||
|
||||
mysql-test:
|
||||
@set -e; \
|
||||
$(MAKE) mysql-up; \
|
||||
trap '$(MAKE) mysql-down' EXIT; \
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
mysql-test-quick:
|
||||
@printf 'Running MySQL tests without container management...\n'
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
@@ -4,29 +4,75 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`)
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||
- 同协议透传(零语义损失、零序列化开销)
|
||||
- 同协议透传(跳过 Canonical 全量转换,保持协议语义)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 支持 Thinking / Reasoning
|
||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
- 结构化日志(zap + lumberjack + 模块标识)
|
||||
- YAML 配置管理
|
||||
- 请求验证
|
||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||
|
||||
## 日志规范
|
||||
|
||||
### 模块标识
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger:
|
||||
|
||||
```go
|
||||
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
输出格式:
|
||||
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
|
||||
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
|
||||
|
||||
### 模块命名规范
|
||||
|
||||
| 模块 | 命名 |
|
||||
|------|------|
|
||||
| ProxyHandler | `handler.proxy` |
|
||||
| ProviderHandler | `handler.provider` |
|
||||
| Provider Client | `provider.client` |
|
||||
| ConversionEngine | `conversion.engine` |
|
||||
| RoutingCache | `service.routing_cache` |
|
||||
| StatsBuffer | `service.stats_buffer` |
|
||||
| Database | `database` |
|
||||
|
||||
### 标准字段
|
||||
|
||||
使用 `pkg/logger/field.go` 中定义的字段构造函数:
|
||||
|
||||
```go
|
||||
logger.Debug("请求开始",
|
||||
pkglogger.Method("POST"),
|
||||
pkglogger.Path("/v1/chat"),
|
||||
pkglogger.RequestID("xxx"),
|
||||
)
|
||||
```
|
||||
|
||||
### GORM 日志
|
||||
|
||||
GORM 日志自动桥接到 zap,SQL 查询映射到 Debug 级别。
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **配置**: Viper + pflag(Server 多层配置,Desktop 配置文件快照)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
@@ -105,18 +151,30 @@ backend/
|
||||
│ │ ├── errors.go
|
||||
│ │ └── wrap.go
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ ├── logger.go
|
||||
│ │ ├── rotate.go
|
||||
│ │ └── context.go
|
||||
│ │ ├── logger.go # 核心初始化
|
||||
│ │ ├── field.go # 标准字段定义
|
||||
│ │ ├── module.go # 模块日志器
|
||||
│ │ ├── context.go # Context 辅助函数
|
||||
│ │ ├── gorm.go # GORM 适配器
|
||||
│ │ ├── minimal.go # 最小化 logger
|
||||
│ │ └── rotate.go # 日志轮转
|
||||
│ ├── modelid/ # 统一模型 ID 工具包
|
||||
│ │ ├── model_id.go
|
||||
│ │ └── model_id_test.go
|
||||
│ └── validator/ # 验证器
|
||||
│ └── validator.go
|
||||
├── migrations/ # 数据库迁移
|
||||
│ ├── 20260401000001_initial_schema.sql
|
||||
│ ├── 20260401000002_add_indexes.sql
|
||||
│ └── 20260419000001_add_provider_protocol.sql
|
||||
│ ├── embed.go # go:embed 迁移资源入口
|
||||
│ ├── sqlite/
|
||||
│ │ └── 20260421000001_initial_schema.sql
|
||||
│ └── mysql/
|
||||
│ └── 20260421000001_initial_schema.sql
|
||||
├── tests/ # 集成测试
|
||||
│ ├── helpers.go
|
||||
│ └── integration/
|
||||
│ ├── helpers.go # 测试辅助函数
|
||||
│ ├── config/ # 测试配置
|
||||
│ ├── integration/ # 集成测试
|
||||
│ │ └── e2e_conversion_test.go # E2E 协议转换测试
|
||||
│ └── mocks/ # Mock 实现
|
||||
├── Makefile
|
||||
├── go.mod
|
||||
└── README.md
|
||||
@@ -145,6 +203,123 @@ Client Request (clientProtocol)
|
||||
|
||||
同协议时自动透传,跳过序列化开销。
|
||||
|
||||
## 协议转换架构
|
||||
|
||||
### Canonical Model 中间表示
|
||||
|
||||
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
|
||||
|
||||
```
|
||||
OpenAI Request → Canonical Request → Anthropic Request
|
||||
(中间表示)
|
||||
OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
```
|
||||
|
||||
**CanonicalRequest 核心字段**:
|
||||
- `Model` - 统一模型 ID
|
||||
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
|
||||
- `Tools` - 工具定义
|
||||
- `Thinking` - 推理配置(`budget_tokens`、`effort`)
|
||||
- `Parameters` - 通用参数(`max_tokens`、`temperature`、`top_p` 等)
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
2. 仅改写请求体中的 model 字段:unified_id → upstream_model_name
|
||||
3. 直接转发请求到上游
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
### InterfaceType 枚举
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| `CHAT` | 对话补全(chat/completions、messages) |
|
||||
| `MODELS` | 模型列表 |
|
||||
| `MODEL_INFO` | 模型详情 |
|
||||
| `EMBEDDINGS` | 嵌入接口 |
|
||||
| `RERANK` | 重排序接口 |
|
||||
| `PASSTHROUGH` | 未知接口,直接透传 |
|
||||
|
||||
## 协议适配器特性
|
||||
|
||||
### OpenAI 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`)
|
||||
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
|
||||
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
|
||||
- `refusal` - 非标准字段,作为 text 块处理
|
||||
|
||||
**废弃字段兼容**:
|
||||
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
|
||||
|
||||
**消息处理**:
|
||||
- 合并连续同角色消息(Anthropic 不支持连续同角色)
|
||||
- 工具选择映射:`any` → `required`
|
||||
|
||||
### Anthropic 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `thinking` - 推理配置(`type: enabled`、`budget_tokens`、`effort`)
|
||||
- `output_config` - 结构化输出配置
|
||||
- `disable_parallel_tool_use` - 禁用并行工具调用
|
||||
- `container` - 工具容器字段
|
||||
|
||||
**不支持的功能**:
|
||||
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
|
||||
|
||||
### 跨协议转换注意事项
|
||||
|
||||
| 源协议 | 目标协议 | 转换说明 |
|
||||
|--------|----------|----------|
|
||||
| OpenAI | Anthropic | `reasoning_effort` → `thinking`,消息角色合并 |
|
||||
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
|
||||
|
||||
## 错误码
|
||||
|
||||
### ConversionError 错误码
|
||||
|
||||
| 错误码 | 说明 |
|
||||
|--------|------|
|
||||
| `INVALID_INPUT` | 输入数据无效 |
|
||||
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
|
||||
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
|
||||
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
|
||||
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
|
||||
| `JSON_PARSE_ERROR` | JSON 解析错误 |
|
||||
| `STREAM_STATE_ERROR` | 流式状态错误 |
|
||||
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
| 错误 | HTTP 状态码 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `ErrModelNotFound` | 404 | 模型未找到 |
|
||||
| `ErrModelDisabled` | 404 | 模型已禁用 |
|
||||
| `ErrProviderNotFound` | 404 | 供应商未找到 |
|
||||
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
|
||||
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
|
||||
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID) |
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
@@ -159,15 +334,18 @@ go mod download
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
服务将在端口 9826 启动。首次启动会自动创建配置文件和运行数据库迁移。
|
||||
服务将在端口 9826 启动。首次启动会自动运行数据库迁移。
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式:配置文件、环境变量、命令行参数,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
配置方式取决于启动入口:
|
||||
|
||||
- **Server 入口**(`cmd/server`):支持 CLI 参数 > 环境变量 > 配置文件 > 默认值
|
||||
- **Desktop 入口**(`cmd/desktop`):仅支持 `~/.nex/config.yaml` > 默认值,不支持 CLI 参数和 `NEX_*` 环境变量覆盖,修改配置文件后需重启生效
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
|
||||
配置文件位于 `~/.nex/config.yaml`。配置文件不存在时使用默认值,不会自动生成;需要自定义时手动创建该文件:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
@@ -176,7 +354,14 @@ server:
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
@@ -190,19 +375,27 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
### 环境变量
|
||||
### 环境变量(仅 Server 入口)
|
||||
|
||||
所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
||||
Server 入口下,所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### 命令行参数
|
||||
### 命令行参数(仅 Server 入口)
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
@@ -214,7 +407,7 @@ export NEX_LOG_LEVEL=debug
|
||||
|
||||
```
|
||||
服务器: --server-port, --server-read-timeout, --server-write-timeout
|
||||
数据库: --database-path, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
||||
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
||||
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
|
||||
通用: --config (指定配置文件路径)
|
||||
```
|
||||
@@ -234,36 +427,56 @@ export NEX_LOG_LEVEL=debug
|
||||
# Docker 部署
|
||||
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
||||
|
||||
# MySQL 模式
|
||||
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
|
||||
|
||||
# 自定义配置文件
|
||||
./server --config /path/to/custom.yaml
|
||||
```
|
||||
|
||||
数据文件:
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
**MySQL 连接说明**:MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
# 运行 backend 默认测试
|
||||
make test
|
||||
|
||||
# 分类测试
|
||||
make test-unit
|
||||
make test-integration
|
||||
|
||||
# 生成覆盖率报告
|
||||
make test-coverage
|
||||
|
||||
# MySQL 专项测试
|
||||
make mysql-up
|
||||
make mysql-down
|
||||
make mysql-test
|
||||
make mysql-test-quick
|
||||
```
|
||||
|
||||
## 数据库迁移
|
||||
|
||||
应用启动时使用随二进制打包的迁移资源(`go:embed`)自动执行迁移,server 和 desktop 发布产物均自包含,不依赖源码目录。开发时可继续通过 Makefile goose CLI 操作文件系统中的 `migrations/<dialect>/` 目录,运行时嵌入资源与文件系统目录共享同一批 SQL 文件。
|
||||
|
||||
```bash
|
||||
# 使用 Makefile
|
||||
make migrate-up DB_PATH=~/.nex/config.db
|
||||
make migrate-down DB_PATH=~/.nex/config.db
|
||||
make migrate-status DB_PATH=~/.nex/config.db
|
||||
make migrate-up DB_DSN=~/.nex/config.db
|
||||
make migrate-down DB_DSN=~/.nex/config.db
|
||||
make migrate-status DB_DSN=~/.nex/config.db
|
||||
|
||||
# 创建新迁移
|
||||
make migrate-create
|
||||
|
||||
# MySQL 迁移
|
||||
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
|
||||
|
||||
# 或直接使用 goose
|
||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
```
|
||||
@@ -272,7 +485,7 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
### 代理接口
|
||||
|
||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
||||
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath,由对应 adapter 识别和组合上游 URL。
|
||||
|
||||
#### OpenAI 协议
|
||||
|
||||
@@ -290,7 +503,19 @@ POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough,跳过 Canonical 全量转换。
|
||||
|
||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||
|
||||
**base_url 约定**:
|
||||
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions` 时,OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`。
|
||||
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`。
|
||||
|
||||
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
|
||||
|
||||
**流式透传边界**:同协议无响应 model 改写时 raw passthrough,保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON,仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
|
||||
|
||||
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`。
|
||||
|
||||
### 管理接口
|
||||
|
||||
@@ -314,7 +539,13 @@ GET /anthropic/v1/models
|
||||
|
||||
**Protocol 字段**:标识上游供应商使用的协议类型,可选值 `"openai"`(默认)、`"anthropic"`。
|
||||
|
||||
**base_url 说明**:应配置到 API 版本路径,不包含具体端点(如 OpenAI: `https://api.openai.com/v1`,GLM: `https://open.bigmodel.cn/api/paas/v4`)。
|
||||
**base_url 说明**:
|
||||
- OpenAI 协议:配置到 API 版本路径,如 `https://api.openai.com/v1`、`https://open.bigmodel.cn/api/paas/v4`
|
||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||
|
||||
**对外 URL 格式**:
|
||||
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions`、`/openai/v1/models`、`/openai/v1/embeddings`
|
||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||
|
||||
#### 模型管理
|
||||
|
||||
@@ -324,14 +555,30 @@ GET /anthropic/v1/models
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
**创建请求**(id 由系统自动生成 UUID):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4"
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"unified_id": "openai/gpt-4",
|
||||
"enabled": true,
|
||||
"created_at": "2026-04-21T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
**统一模型 ID**:`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。
|
||||
|
||||
#### 统计查询
|
||||
|
||||
- `GET /api/stats` - 查询统计
|
||||
@@ -339,6 +586,20 @@ GET /anthropic/v1/models
|
||||
|
||||
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
||||
|
||||
#### 版本信息
|
||||
|
||||
- `GET /api/version` - 获取后端构建版本信息
|
||||
|
||||
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": "0.1.0",
|
||||
"commit": "abc1234",
|
||||
"build_time": "2026-05-05T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
#### 健康检查
|
||||
|
||||
- `GET /health` - 返回 `{"status": "ok"}`
|
||||
@@ -346,9 +607,12 @@ GET /anthropic/v1/models
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
make build # 构建
|
||||
make lint # 代码检查
|
||||
make deps # 整理依赖
|
||||
make build # 构建 backend/bin/server
|
||||
make run # 运行后端服务
|
||||
make lint # 代码检查
|
||||
make clean # 清理 backend 构建产物
|
||||
go mod tidy # 整理依赖
|
||||
go generate ./... # 刷新 mock 等生成代码
|
||||
```
|
||||
|
||||
环境要求:Go 1.26 或更高版本
|
||||
@@ -397,6 +661,7 @@ err := v.Validate(myStruct)
|
||||
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||
|
||||
25
backend/cmd/desktop/dialog_darwin.go
Normal file
25
backend/cmd/desktop/dialog_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return s
|
||||
}
|
||||
67
backend/cmd/desktop/dialog_linux.go
Normal file
67
backend/cmd/desktop/dialog_linux.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type dialogToolType int
|
||||
|
||||
const (
|
||||
toolNone dialogToolType = iota
|
||||
toolZenity
|
||||
toolKdialog
|
||||
toolNotifySend
|
||||
toolXmessage
|
||||
)
|
||||
|
||||
var (
|
||||
dialogTool dialogToolType
|
||||
dialogToolOnce sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialogToolOnce.Do(detectDialogTool)
|
||||
}
|
||||
|
||||
func detectDialogTool() {
|
||||
tools := []struct {
|
||||
name string
|
||||
typ dialogToolType
|
||||
}{
|
||||
{"zenity", toolZenity},
|
||||
{"kdialog", toolKdialog},
|
||||
{"notify-send", toolNotifySend},
|
||||
{"xmessage", toolXmessage},
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool.name); err == nil {
|
||||
dialogTool = tool.typ
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dialogTool = toolNone
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch dialogTool {
|
||||
case toolZenity:
|
||||
exec.Command("zenity", "--error",
|
||||
fmt.Sprintf("--title=%s", title),
|
||||
fmt.Sprintf("--text=%s", message)).Run()
|
||||
case toolKdialog:
|
||||
exec.Command("kdialog", "--error", message, "--title", title).Run()
|
||||
case toolNotifySend:
|
||||
exec.Command("notify-send", "-u", "critical", title, message).Run()
|
||||
case toolXmessage:
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("%s: %s", title, message)).Run()
|
||||
default:
|
||||
dialogLogger().Error("无法显示错误对话框")
|
||||
}
|
||||
}
|
||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func dialogLogger() *zap.Logger {
|
||||
if zapLogger != nil {
|
||||
return zapLogger
|
||||
}
|
||||
|
||||
return pkgLogger.NewMinimal()
|
||||
}
|
||||
62
backend/cmd/desktop/dialog_windows.go
Normal file
62
backend/cmd/desktop/dialog_windows.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
mbIconError = 0x10
|
||||
mbIconInformation = 0x40
|
||||
)
|
||||
|
||||
var (
|
||||
user32 = syscall.NewLazyDLL("user32.dll")
|
||||
procMessageBoxW = user32.NewProc("MessageBoxW")
|
||||
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
|
||||
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
|
||||
return ret, err
|
||||
}
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
if err := messageBox(title, message, mbIconError); err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func messageBox(title, message string, flags uint) error {
|
||||
titlePtr, err := syscall.UTF16PtrFromString(title)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagePtr, err := syscall.UTF16PtrFromString(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret, callErr := callMessageBoxW(
|
||||
0,
|
||||
uintptr(unsafe.Pointer(messagePtr)),
|
||||
uintptr(unsafe.Pointer(titlePtr)),
|
||||
uintptr(flags),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
|
||||
return callErr
|
||||
}
|
||||
|
||||
return fmt.Errorf("MessageBoxW 调用失败")
|
||||
}
|
||||
33
backend/cmd/desktop/icon_test.go
Normal file
33
backend/cmd/desktop/icon_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
func TestIconSelection_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("图标格式选择测试仅在 Windows 上运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.ico"); err != nil {
|
||||
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIconSelection_NonWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("图标格式选择测试在非 Windows 平台运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.png"); err != nil {
|
||||
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIconLoad(path string) error {
|
||||
_, err := embedfs.Assets.ReadFile(path)
|
||||
return err
|
||||
}
|
||||
1
backend/cmd/desktop/icon_windows.rc
Normal file
1
backend/cmd/desktop/icon_windows.rc
Normal file
@@ -0,0 +1 @@
|
||||
1 ICON "../../../assets/icon.ico"
|
||||
442
backend/cmd/desktop/main.go
Normal file
442
backend/cmd/desktop/main.go
Normal file
@@ -0,0 +1,442 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
server *http.Server
|
||||
zapLogger *zap.Logger
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
)
|
||||
|
||||
func main() {
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, cfgMeta, err := config.LoadDesktopConfigWithMetadata()
|
||||
if err != nil {
|
||||
minimalLogger.Error("加载配置失败", zap.Error(err))
|
||||
showError(appName, desktopConfigErrorMessage(getDesktopConfigPath(), err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
port := cfg.Server.Port
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
minimalLogger.Error("端口不可用", zap.Error(err))
|
||||
showError(appName, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||
if err := singleLock.Lock(); err != nil {
|
||||
minimalLogger.Error("已有 Nex 实例运行")
|
||||
showError(appName, "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() {
|
||||
if err := singleLock.Unlock(); err != nil {
|
||||
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
MaxBackups: cfg.Log.MaxBackups,
|
||||
MaxAge: cfg.Log.MaxAge,
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer database.Close(db)
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
defer statsBuffer.Stop()
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
versionHandler := handler.NewVersionHandler()
|
||||
settingsHandler := handler.NewSettingsHandler(cfg, "desktop", true, cfgMeta.ConfigPath)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
|
||||
setupStaticFiles(r)
|
||||
|
||||
server = &http.Server{
|
||||
Addr: desktopListenAddr(port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
zap.String("addr", server.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(desktopURL(port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
setupSystray(port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
|
||||
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
|
||||
r.GET("/api/version", versionHandler.GetVersion)
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
settings := r.Group("/api/settings")
|
||||
{
|
||||
settings.GET("/startup", settingsHandler.GetStartupSettings)
|
||||
settings.PUT("/startup", settingsHandler.SaveStartupSettings)
|
||||
}
|
||||
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
|
||||
next(c)
|
||||
}
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
}
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
}
|
||||
|
||||
func frontendDistFS() (fs.FS, error) {
|
||||
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
}
|
||||
|
||||
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
if strings.HasSuffix(path, ".png") {
|
||||
return "image/png"
|
||||
}
|
||||
if strings.HasSuffix(path, ".ico") {
|
||||
return "image/x-icon"
|
||||
}
|
||||
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
|
||||
return "font/woff2"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/icon.png", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "icon.png")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/png", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/openai/") ||
|
||||
strings.HasPrefix(path, "/anthropic/") ||
|
||||
path == "/openai" ||
|
||||
path == "/anthropic" ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
}
|
||||
|
||||
func setupSystray(port int) {
|
||||
systray.Run(func() {
|
||||
var icon []byte
|
||||
var err error
|
||||
if runtime.GOOS == "windows" {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
|
||||
} else {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.Error(err))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTooltip(appTooltip)
|
||||
|
||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
systray.AddSeparator()
|
||||
mStatus := systray.AddMenuItem("状态: 运行中", "")
|
||||
mStatus.Disable()
|
||||
mPort := systray.AddMenuItem(desktopPortMenuTitle(port), "")
|
||||
mPort.Disable()
|
||||
systray.AddSeparator()
|
||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
if err := openBrowser(desktopURL(port)); err != nil {
|
||||
zapLogger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mQuit.ClickedCh:
|
||||
doShutdown()
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func doShutdown() {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func getDesktopConfigPath() string {
|
||||
configDir, err := config.GetConfigDir()
|
||||
if err != nil {
|
||||
return "~/.nex/config.yaml"
|
||||
}
|
||||
return filepath.Join(configDir, "config.yaml")
|
||||
}
|
||||
|
||||
func desktopConfigErrorMessage(configPath string, err error) string {
|
||||
return fmt.Sprintf("加载配置失败\n\n配置文件: %s\n\n%v", configPath, err)
|
||||
}
|
||||
|
||||
func desktopListenAddr(port int) string {
|
||||
return fmt.Sprintf(":%d", port)
|
||||
}
|
||||
|
||||
func desktopURL(port int) string {
|
||||
return fmt.Sprintf("http://localhost:%d", port)
|
||||
}
|
||||
|
||||
func desktopPortMenuTitle(port int) string {
|
||||
return fmt.Sprintf("端口: %d", port)
|
||||
}
|
||||
|
||||
func checkPortAvailable(port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type SingletonLock struct {
|
||||
flock *flock.Flock
|
||||
}
|
||||
|
||||
func NewSingletonLock(lockPath string) *SingletonLock {
|
||||
return &SingletonLock{
|
||||
flock: flock.New(lockPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Lock() error {
|
||||
locked, err := s.flock.TryLock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !locked {
|
||||
return fmt.Errorf("已有实例运行")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Unlock() error {
|
||||
return s.flock.Unlock()
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
case "linux":
|
||||
browsers := []string{"xdg-open", "google-chrome", "firefox"}
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
cmd = exec.Command(browser, url)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil {
|
||||
return fmt.Errorf("无法打开浏览器")
|
||||
}
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
61
backend/cmd/desktop/messagebox_test.go
Normal file
61
backend/cmd/desktop/messagebox_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
|
||||
t.Helper()
|
||||
|
||||
old := callMessageBoxW
|
||||
callMessageBoxW = fn
|
||||
t.Cleanup(func() {
|
||||
callMessageBoxW = old
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
|
||||
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
|
||||
if err == nil {
|
||||
t.Fatal("包含 NUL 字符时应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 1, syscall.Errno(123)
|
||||
})
|
||||
|
||||
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
|
||||
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
err := messageBox("测试标题", "测试消息", mbIconInformation)
|
||||
if !errors.Is(err, syscall.Errno(5)) {
|
||||
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowError_WindowsBranch(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
|
||||
showError("测试错误", "这是一条测试错误消息")
|
||||
}
|
||||
9
backend/cmd/desktop/metadata.go
Normal file
9
backend/cmd/desktop/metadata.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package main
|
||||
|
||||
const (
|
||||
appName = "Nex"
|
||||
appTooltip = appName
|
||||
appDescription = "AI Gateway - 统一的大模型 API 网关"
|
||||
// #nosec G101 -- 项目官网地址不是凭据
|
||||
appWebsite = "https://github.com/nex/gateway"
|
||||
)
|
||||
13
backend/cmd/desktop/metadata_test.go
Normal file
13
backend/cmd/desktop/metadata_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDesktopMetadata(t *testing.T) {
|
||||
if appName != "Nex" {
|
||||
t.Fatalf("appName = %q, want %q", appName, "Nex")
|
||||
}
|
||||
|
||||
if appTooltip != appName {
|
||||
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
|
||||
}
|
||||
}
|
||||
43
backend/cmd/desktop/migration_test.go
Normal file
43
backend/cmd/desktop/migration_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestDesktop_InitMigrationsWithoutSourceTree(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
if err == nil {
|
||||
defer func() {
|
||||
if chdirErr := os.Chdir(origDir); chdirErr != nil {
|
||||
t.Logf("无法恢复工作目录: %v", chdirErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if chdirErr := os.Chdir(tmpDir); chdirErr != nil {
|
||||
t.Skipf("无法切换到临时目录: %v", chdirErr)
|
||||
}
|
||||
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(tmpDir, "nex-test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := database.Init(cfg, zapLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("在无源码目录环境下数据库初始化应成功,但返回错误: %v", err)
|
||||
}
|
||||
database.Close(db)
|
||||
}
|
||||
129
backend/cmd/desktop/port_test.go
Normal file
129
backend/cmd/desktop/port_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCheckPortAvailable(t *testing.T) {
|
||||
port := 19826
|
||||
|
||||
err := checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口 %d 应该可用: %v", port, err)
|
||||
}
|
||||
|
||||
t.Log("端口可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
t.Log("端口占用检测测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
listener.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口关闭后应该可用: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口关闭后可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableErrorContainsPort(t *testing.T) {
|
||||
port := 19829
|
||||
|
||||
listener, err := net.Listen("tcp", ":19829") //nolint:gosec
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "19829") {
|
||||
t.Fatalf("错误信息应包含端口号 19829,实际: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口错误信息包含端口号测试通过")
|
||||
}
|
||||
|
||||
func TestGetDesktopConfigPath(t *testing.T) {
|
||||
path := getDesktopConfigPath()
|
||||
if path == "" {
|
||||
t.Fatal("getDesktopConfigPath 应返回非空路径")
|
||||
}
|
||||
if !strings.Contains(path, "config.yaml") {
|
||||
t.Fatalf("路径应包含 config.yaml,实际: %s", path)
|
||||
}
|
||||
t.Log("getDesktopConfigPath 测试通过")
|
||||
}
|
||||
|
||||
func TestDesktopConfiguredPortHelpers(t *testing.T) {
|
||||
port := 19830
|
||||
|
||||
if got := desktopListenAddr(port); got != ":19830" {
|
||||
t.Fatalf("HTTP 监听地址应使用配置端口,实际: %s", got)
|
||||
}
|
||||
if got := desktopURL(port); got != "http://localhost:19830" {
|
||||
t.Fatalf("浏览器 URL 应使用配置端口,实际: %s", got)
|
||||
}
|
||||
if got := desktopPortMenuTitle(port); got != "端口: 19830" {
|
||||
t.Fatalf("托盘端口显示应使用配置端口,实际: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDesktopConfigErrorMessageContainsPathAndReason(t *testing.T) {
|
||||
msg := desktopConfigErrorMessage("/tmp/nex/config.yaml", errors.New("yaml parse failed"))
|
||||
|
||||
if !strings.Contains(msg, "/tmp/nex/config.yaml") {
|
||||
t.Fatalf("配置错误提示应包含配置路径,实际: %s", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "yaml parse failed") {
|
||||
t.Fatalf("配置错误提示应包含失败原因,实际: %s", msg)
|
||||
}
|
||||
}
|
||||
44
backend/cmd/desktop/routes_test.go
Normal file
44
backend/cmd/desktop/routes_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"nex/backend/internal/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler(), handler.NewSettingsHandler(nil, "desktop", true, ""))
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
|
||||
t.Fatalf("版本接口不应返回 SPA fallback HTML")
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
for _, key := range []string{"version", "commit", "build_time"} {
|
||||
if result[key] == "" {
|
||||
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
74
backend/cmd/desktop/singleton_test.go
Normal file
74
backend/cmd/desktop/singleton_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-first.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock := NewSingletonLock(lockPath)
|
||||
if err := lock.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-dup.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
err := lock2.Lock()
|
||||
if err == nil {
|
||||
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||
t.Fatalf("解锁失败: %v", unlockErr)
|
||||
}
|
||||
t.Fatal("重复加锁应失败,但返回 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-relock.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
if err := lock2.Lock(); err != nil {
|
||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock2.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
238
backend/cmd/desktop/static_test.go
Normal file
238
backend/cmd/desktop/static_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/anthropic/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for JS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for CSS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.css", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
|
||||
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"icon.png": {Data: []byte("png")},
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/icon.png", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "image/png" {
|
||||
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
if w.Body.String() != "png" {
|
||||
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var gotProtocol string
|
||||
var gotPath string
|
||||
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" {
|
||||
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "anthropic" {
|
||||
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/messages" {
|
||||
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Static assets are not hijacked", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if gotProtocol != "" || gotPath != "" {
|
||||
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA path keeps fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
|
||||
t.Errorf("期望返回 HTML,实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" || gotPath != "/unknown" {
|
||||
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,46 +3,38 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 1. 加载配置(已包含 CLI 参数解析、环境变量绑定、配置文件读取和验证)
|
||||
cfg, err := config.LoadConfig()
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, cfgMeta, err := config.LoadServerConfigWithMetadata()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 2. 打印配置摘要
|
||||
cfg.PrintSummary()
|
||||
|
||||
// 3. 初始化日志
|
||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
||||
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -51,48 +43,59 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化日志失败: %v", err)
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 初始化数据库
|
||||
db, err := initDatabase(cfg)
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer closeDB(db)
|
||||
defer database.Close(db)
|
||||
|
||||
// 4. 初始化 repository 层
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 6. 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
// 7. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
// 8. 初始化 handler 层
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
versionHandler := handler.NewVersionHandler()
|
||||
settingsHandler := handler.NewSettingsHandler(cfg, "server", false, cfgMeta.ConfigPath)
|
||||
|
||||
// 9. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
@@ -101,20 +104,23 @@ func main() {
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
|
||||
|
||||
// 10. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
zap.String("addr", srv.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -128,89 +134,18 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
|
||||
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
||||
r.GET("/api/version", versionHandler.GetVersion)
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir()
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
goose.SetDialect("sqlite3")
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrationsDir() string {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
func closeDB(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func formatAddr(port int) string {
|
||||
return fmt.Sprintf(":%d", port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// 统一代理入口: /{protocol}/v1/{path}
|
||||
r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy)
|
||||
|
||||
// 供应商管理 API
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
@@ -220,7 +155,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
// 模型管理 API
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
@@ -230,14 +164,18 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
// 统计查询 API
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
// 健康检查
|
||||
settings := r.Group("/api/settings")
|
||||
{
|
||||
settings.GET("/startup", settingsHandler.GetStartupSettings)
|
||||
settings.PUT("/startup", settingsHandler.SaveStartupSettings)
|
||||
}
|
||||
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
37
backend/cmd/server/routes_test.go
Normal file
37
backend/cmd/server/routes_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupRoutes_Version(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler(), handler.NewSettingsHandler(nil, "server", false, ""))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
for _, key := range []string{"version", "commit", "build_time"} {
|
||||
if result[key] == "" {
|
||||
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
180
backend/go.mod
180
backend/go.mod
@@ -2,69 +2,249 @@ module nex/backend
|
||||
|
||||
go 1.26.2
|
||||
|
||||
replace nex/embedfs => ../embedfs
|
||||
|
||||
tool (
|
||||
github.com/golangci/golangci-lint/cmd/golangci-lint
|
||||
go.uber.org/mock/mockgen
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/getlantern/systray v1.2.2
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/gofrs/flock v0.13.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/mock v0.6.0
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
nex/embedfs v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
require (
|
||||
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
||||
4d63.com/gochecknoglobals v0.2.2 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/4meepo/tagalign v1.4.2 // indirect
|
||||
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
||||
github.com/Antonboom/errname v1.0.0 // indirect
|
||||
github.com/Antonboom/nilnil v1.0.1 // indirect
|
||||
github.com/Antonboom/testifylint v1.5.2 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
|
||||
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
|
||||
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
|
||||
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
|
||||
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
|
||||
github.com/alexkohler/prealloc v1.0.0 // indirect
|
||||
github.com/alingse/asasalint v0.0.11 // indirect
|
||||
github.com/alingse/nilnesserr v0.1.2 // indirect
|
||||
github.com/ashanbrown/forbidigo v1.6.0 // indirect
|
||||
github.com/ashanbrown/makezero v1.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bkielbasa/cyclop v1.2.3 // indirect
|
||||
github.com/blizzy78/varnamelen v0.8.0 // indirect
|
||||
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
|
||||
github.com/breml/bidichk v0.3.2 // indirect
|
||||
github.com/breml/errchkjson v0.4.0 // indirect
|
||||
github.com/butuzov/ireturn v0.3.1 // indirect
|
||||
github.com/butuzov/mirror v1.3.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/catenacyber/perfsprint v0.8.2 // indirect
|
||||
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charithe/durationcheck v0.0.10 // indirect
|
||||
github.com/chavacava/garif v0.1.0 // indirect
|
||||
github.com/ckaznocha/intrange v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/curioswitch/go-reassign v0.3.0 // indirect
|
||||
github.com/daixiang0/gci v0.13.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/denis-tingaikin/go-header v0.5.0 // indirect
|
||||
github.com/ettle/strcase v0.2.0 // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/fatih/structtag v1.2.0 // indirect
|
||||
github.com/firefart/nonamedreturns v1.0.5 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/fzipp/gocyclo v0.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/ghostiam/protogetter v0.3.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-critic/go-critic v0.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astequal v1.2.0 // indirect
|
||||
github.com/go-toolsmith/astfmt v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astp v1.1.0 // indirect
|
||||
github.com/go-toolsmith/strparse v1.1.0 // indirect
|
||||
github.com/go-toolsmith/typep v1.1.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
|
||||
github.com/golangci/go-printf-func-name v0.1.0 // indirect
|
||||
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
|
||||
github.com/golangci/golangci-lint v1.64.8 // indirect
|
||||
github.com/golangci/misspell v0.6.0 // indirect
|
||||
github.com/golangci/plugin-module-register v0.1.1 // indirect
|
||||
github.com/golangci/revgrep v0.8.0 // indirect
|
||||
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gordonklaus/ineffassign v0.1.0 // indirect
|
||||
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
|
||||
github.com/gostaticanalysis/comment v1.5.0 // indirect
|
||||
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
|
||||
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
||||
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
|
||||
github.com/hashicorp/go-version v1.7.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jgautheron/goconst v1.7.1 // indirect
|
||||
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jjti/go-spancheck v0.6.4 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/julz/importas v0.2.0 // indirect
|
||||
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
|
||||
github.com/kisielk/errcheck v1.9.0 // indirect
|
||||
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kulti/thelper v0.6.3 // indirect
|
||||
github.com/kunwardeep/paralleltest v1.0.10 // indirect
|
||||
github.com/lasiar/canonicalheader v1.1.2 // indirect
|
||||
github.com/ldez/exptostd v0.4.2 // indirect
|
||||
github.com/ldez/gomoddirectives v0.6.1 // indirect
|
||||
github.com/ldez/grignotin v0.9.0 // indirect
|
||||
github.com/ldez/tagliatelle v0.7.1 // indirect
|
||||
github.com/ldez/usetesting v0.4.2 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/leonklingele/grouper v1.1.2 // indirect
|
||||
github.com/macabu/inamedparam v0.1.3 // indirect
|
||||
github.com/maratori/testableexamples v1.0.0 // indirect
|
||||
github.com/maratori/testpackage v1.1.1 // indirect
|
||||
github.com/matoous/godox v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/mgechev/revive v1.7.0 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/moricho/tparallel v0.3.2 // indirect
|
||||
github.com/nakabonne/nestif v0.3.1 // indirect
|
||||
github.com/nishanths/exhaustive v0.12.0 // indirect
|
||||
github.com/nishanths/predeclared v0.2.2 // indirect
|
||||
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
|
||||
github.com/prometheus/client_golang v1.12.1 // indirect
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
github.com/prometheus/common v0.32.1 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
|
||||
github.com/quasilyte/gogrep v0.5.0 // indirect
|
||||
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
|
||||
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/raeperd/recvcheck v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/ryancurrah/gomodguard v1.3.5 // indirect
|
||||
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
|
||||
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
|
||||
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
|
||||
github.com/securego/gosec/v2 v2.22.2 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sivchari/containedctx v1.0.3 // indirect
|
||||
github.com/sivchari/tenv v1.12.1 // indirect
|
||||
github.com/sonatard/noctx v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/sourcegraph/go-diff v0.7.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/cobra v1.9.1 // indirect
|
||||
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
|
||||
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tdakkota/asciicheck v0.4.1 // indirect
|
||||
github.com/tetafro/godot v1.5.0 // indirect
|
||||
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
|
||||
github.com/timonwong/loggercheck v0.10.1 // indirect
|
||||
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/ultraware/funlen v0.2.0 // indirect
|
||||
github.com/ultraware/whitespace v0.2.0 // indirect
|
||||
github.com/uudashr/gocognit v1.2.0 // indirect
|
||||
github.com/uudashr/iface v1.3.1 // indirect
|
||||
github.com/xen0n/gosmopolitan v1.2.2 // indirect
|
||||
github.com/yagipy/maintidx v1.0.0 // indirect
|
||||
github.com/yeya24/promlinter v0.3.0 // indirect
|
||||
github.com/ykadowak/zerologlint v0.1.5 // indirect
|
||||
gitlab.com/bosi/decorder v0.4.2 // indirect
|
||||
go-simpler.org/musttag v0.13.0 // indirect
|
||||
go-simpler.org/sloglint v0.9.0 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
honnef.co/go/tools v0.6.1 // indirect
|
||||
mvdan.cc/gofumpt v0.7.0 // indirect
|
||||
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
|
||||
)
|
||||
|
||||
919
backend/go.sum
919
backend/go.sum
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -32,7 +33,13 @@ type ServerConfig struct {
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required"`
|
||||
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
|
||||
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
|
||||
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,omitempty,min=1,max=65535"`
|
||||
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
|
||||
Password string `yaml:"password" mapstructure:"password"`
|
||||
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
||||
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
||||
@@ -51,7 +58,10 @@ type LogConfig struct {
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
@@ -61,7 +71,13 @@ func DefaultConfig() *Config {
|
||||
WriteTimeout: 30 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(nexDir, "config.db"),
|
||||
Host: "",
|
||||
Port: 3306,
|
||||
User: "",
|
||||
Password: "",
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 1 * time.Hour,
|
||||
@@ -84,7 +100,7 @@ func GetConfigDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
@@ -110,14 +126,23 @@ func GetConfigPath() (string, error) {
|
||||
|
||||
// setupDefaults 设置默认配置值
|
||||
func setupDefaults(v *viper.Viper) {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
v.SetDefault("server.port", 9826)
|
||||
v.SetDefault("server.read_timeout", "30s")
|
||||
v.SetDefault("server.write_timeout", "30s")
|
||||
|
||||
v.SetDefault("database.driver", "sqlite")
|
||||
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
|
||||
v.SetDefault("database.host", "")
|
||||
v.SetDefault("database.port", 3306)
|
||||
v.SetDefault("database.user", "")
|
||||
v.SetDefault("database.password", "")
|
||||
v.SetDefault("database.dbname", "nex")
|
||||
v.SetDefault("database.max_idle_conns", 10)
|
||||
v.SetDefault("database.max_open_conns", 100)
|
||||
v.SetDefault("database.conn_max_lifetime", "1h")
|
||||
@@ -138,7 +163,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
flagSet.Duration("server-read-timeout", 0, "读超时")
|
||||
flagSet.Duration("server-write-timeout", 0, "写超时")
|
||||
|
||||
flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
|
||||
flagSet.String("database-path", "", "数据库文件路径")
|
||||
flagSet.String("database-host", "", "MySQL 主机地址")
|
||||
flagSet.Int("database-port", 0, "MySQL 端口")
|
||||
flagSet.String("database-user", "", "MySQL 用户名")
|
||||
flagSet.String("database-password", "", "MySQL 密码")
|
||||
flagSet.String("database-dbname", "", "MySQL 数据库名")
|
||||
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
||||
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
||||
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
||||
@@ -152,21 +183,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
|
||||
// 绑定所有 flag 到 viper
|
||||
// 注意:必须在设置默认值之后绑定
|
||||
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
|
||||
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
|
||||
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
|
||||
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
|
||||
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
|
||||
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
|
||||
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
|
||||
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
|
||||
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
|
||||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||
}
|
||||
|
||||
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||
if err := v.BindPFlag(key, flag); err != nil {
|
||||
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||
}
|
||||
}
|
||||
|
||||
// setupEnv 绑定环境变量
|
||||
@@ -181,73 +224,156 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
|
||||
v.SetConfigFile(configPath)
|
||||
v.SetConfigType("yaml")
|
||||
|
||||
// 尝试读取配置文件,如果不存在则忽略
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
// 配置文件不存在,创建默认配置文件
|
||||
if err := v.SafeWriteConfig(); err != nil {
|
||||
// 忽略写入错误(可能目录已存在等)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig loads config from YAML file, creates default if not exists
|
||||
func LoadConfig() (*Config, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return LoadConfigFromPath(configPath)
|
||||
type ConfigMetadata struct {
|
||||
ConfigPath string
|
||||
}
|
||||
|
||||
// LoadConfigFromPath 从指定路径加载配置
|
||||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
// 1. 创建 Viper 实例
|
||||
type loadOptions struct {
|
||||
configPathOverride string
|
||||
useCLI bool
|
||||
useEnv bool
|
||||
useConfigFlag bool
|
||||
}
|
||||
|
||||
// resolveConfigPath 根据 loadOptions 解析 CLI 参数并返回最终配置文件路径
|
||||
func resolveConfigPath(v *viper.Viper, opts loadOptions) (string, error) {
|
||||
configPath := opts.configPathOverride
|
||||
|
||||
if !opts.useCLI && !opts.useConfigFlag {
|
||||
return configPath, nil
|
||||
}
|
||||
|
||||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||||
if opts.useConfigFlag {
|
||||
flagSet.String("config", opts.configPathOverride, "配置文件路径")
|
||||
}
|
||||
if opts.useCLI {
|
||||
setupFlags(v, flagSet)
|
||||
}
|
||||
|
||||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||
return "", appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||
}
|
||||
|
||||
if opts.useConfigFlag {
|
||||
if f, err := flagSet.GetString("config"); err == nil && f != "" {
|
||||
configPath = f
|
||||
}
|
||||
}
|
||||
|
||||
return configPath, nil
|
||||
}
|
||||
|
||||
func loadConfig(opts loadOptions) (*Config, error) {
|
||||
cfg, _, err := loadConfigWithMetadata(opts)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func loadConfigWithMetadata(opts loadOptions) (*Config, ConfigMetadata, error) {
|
||||
v := viper.New()
|
||||
|
||||
// 2. 定义 CLI 参数
|
||||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||||
flagSet.String("config", configPath, "配置文件路径")
|
||||
setupFlags(v, flagSet)
|
||||
|
||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||||
flagSet.Parse(os.Args[1:])
|
||||
|
||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||||
configPath = configPathFlag
|
||||
}
|
||||
|
||||
// 5. 设置默认值
|
||||
setupDefaults(v)
|
||||
|
||||
// 6. 绑定环境变量
|
||||
setupEnv(v)
|
||||
|
||||
// 7. 读取配置文件
|
||||
if err := setupConfigFile(v, configPath); err != nil {
|
||||
return nil, err
|
||||
configPath, err := resolveConfigPath(v, opts)
|
||||
if err != nil {
|
||||
return nil, ConfigMetadata{}, err
|
||||
}
|
||||
|
||||
if opts.useEnv {
|
||||
setupEnv(v)
|
||||
}
|
||||
|
||||
if err := setupConfigFile(v, configPath); err != nil {
|
||||
return nil, ConfigMetadata{}, err
|
||||
}
|
||||
|
||||
// 8. 反序列化到结构体
|
||||
cfg := &Config{}
|
||||
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
))); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
// 9. 验证配置
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
return nil, ConfigMetadata{}, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
return cfg, ConfigMetadata{ConfigPath: configPath}, nil
|
||||
}
|
||||
|
||||
// LoadServerConfig 为 server 入口加载配置,支持 CLI 参数、环境变量和 --config
|
||||
func LoadServerConfig() (*Config, error) {
|
||||
cfg, _, err := LoadServerConfigWithMetadata()
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func LoadServerConfigWithMetadata() (*Config, ConfigMetadata, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return loadConfigWithMetadata(loadOptions{
|
||||
configPathOverride: configPath,
|
||||
useCLI: true,
|
||||
useEnv: true,
|
||||
useConfigFlag: true,
|
||||
})
|
||||
}
|
||||
|
||||
// LoadDesktopConfig 为 desktop 入口加载配置,固定使用默认配置文件,不支持 CLI、环境变量和 --config
|
||||
func LoadDesktopConfig() (*Config, error) {
|
||||
cfg, _, err := LoadDesktopConfigWithMetadata()
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func LoadDesktopConfigWithMetadata() (*Config, ConfigMetadata, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return loadConfigWithMetadata(loadOptions{
|
||||
configPathOverride: configPath,
|
||||
useCLI: false,
|
||||
useEnv: false,
|
||||
useConfigFlag: false,
|
||||
})
|
||||
}
|
||||
|
||||
// LoadConfig loads config from YAML file.
|
||||
// 向后兼容,等同于 LoadServerConfig。
|
||||
func LoadConfig() (*Config, error) {
|
||||
return LoadServerConfig()
|
||||
}
|
||||
|
||||
// LoadConfigFromPath 从指定路径加载配置。
|
||||
// 保留向后兼容,沿用 server 语义(支持 CLI、env 和 --config 覆盖)。
|
||||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
return loadConfig(loadOptions{
|
||||
configPathOverride: configPath,
|
||||
useCLI: true,
|
||||
useEnv: true,
|
||||
useConfigFlag: true,
|
||||
})
|
||||
}
|
||||
|
||||
// LoadDesktopConfigAtPath 从指定路径以 desktop 语义加载配置(仅配置文件和默认值),用于测试场景。
|
||||
func LoadDesktopConfigAtPath(configPath string) (*Config, error) {
|
||||
return loadConfig(loadOptions{
|
||||
configPathOverride: configPath,
|
||||
useCLI: false,
|
||||
useEnv: false,
|
||||
useConfigFlag: false,
|
||||
})
|
||||
}
|
||||
|
||||
// SaveConfig saves config to YAML file
|
||||
@@ -256,19 +382,21 @@ func SaveConfig(cfg *Config) error {
|
||||
if err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return SaveConfigToPath(cfg, configPath)
|
||||
}
|
||||
|
||||
func SaveConfigToPath(cfg *Config, configPath string) error {
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
return os.WriteFile(configPath, data, 0o600)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
@@ -281,16 +409,24 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
|
||||
// PrintSummary 打印配置摘要
|
||||
func (c *Config) PrintSummary() {
|
||||
fmt.Println("\nAI Gateway 启动配置")
|
||||
fmt.Println("==================")
|
||||
fmt.Printf("服务器端口: %d\n", c.Server.Port)
|
||||
fmt.Printf("数据库路径: %s\n", c.Database.Path)
|
||||
fmt.Printf("日志级别: %s\n", c.Log.Level)
|
||||
fmt.Println("\n配置来源:")
|
||||
configPath, _ := GetConfigPath()
|
||||
fmt.Printf(" 配置文件: %s\n", configPath)
|
||||
fmt.Println(" 环境变量: 待统计")
|
||||
fmt.Println(" CLI 参数: 待统计")
|
||||
fmt.Println()
|
||||
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||||
logger.Info("AI Gateway 启动配置",
|
||||
zap.Int("server_port", c.Server.Port),
|
||||
zap.String("database_driver", c.Database.Driver),
|
||||
zap.String("log_level", c.Log.Level),
|
||||
)
|
||||
|
||||
if c.Database.Driver == "mysql" {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "mysql"),
|
||||
zap.String("host", c.Database.Host),
|
||||
zap.Int("port", c.Database.Port),
|
||||
zap.String("database", c.Database.DBName),
|
||||
)
|
||||
} else {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "sqlite"),
|
||||
zap.String("path", c.Database.Path),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
94
backend/internal/config/config_metadata_test.go
Normal file
94
backend/internal/config/config_metadata_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestLoadDesktopConfigAtPath_WithMetadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Server.Port = 8888
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(configPath, data, 0o600))
|
||||
|
||||
loaded, meta, err := loadConfigWithMetadata(loadOptions{
|
||||
configPathOverride: configPath,
|
||||
useCLI: false,
|
||||
useEnv: false,
|
||||
useConfigFlag: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 8888, loaded.Server.Port)
|
||||
assert.Equal(t, configPath, meta.ConfigPath)
|
||||
}
|
||||
|
||||
func TestSaveConfigToPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "sub", "config.yaml")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Server.Port = 7777
|
||||
|
||||
err := SaveConfigToPath(cfg, configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), "7777")
|
||||
}
|
||||
|
||||
func TestSaveConfigToPath_InvalidDir(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
err := SaveConfigToPath(cfg, "/dev/null/impossible/config.yaml")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDurationConversion(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
dto := configToDTO(cfg)
|
||||
|
||||
parsed, err := time.ParseDuration(dto.Server.ReadTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cfg.Server.ReadTimeout, parsed)
|
||||
|
||||
parsed, err = time.ParseDuration(dto.Database.ConnMaxLifetime)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cfg.Database.ConnMaxLifetime, parsed)
|
||||
}
|
||||
|
||||
func configToDTO(c *Config) struct {
|
||||
Server struct {
|
||||
Port int `json:"port"`
|
||||
ReadTimeout string `json:"read_timeout"`
|
||||
WriteTimeout string `json:"write_timeout"`
|
||||
}
|
||||
Database struct {
|
||||
ConnMaxLifetime string `json:"conn_max_lifetime"`
|
||||
}
|
||||
} {
|
||||
var result struct {
|
||||
Server struct {
|
||||
Port int `json:"port"`
|
||||
ReadTimeout string `json:"read_timeout"`
|
||||
WriteTimeout string `json:"write_timeout"`
|
||||
}
|
||||
Database struct {
|
||||
ConnMaxLifetime string `json:"conn_max_lifetime"`
|
||||
}
|
||||
}
|
||||
result.Server.Port = c.Server.Port
|
||||
result.Server.ReadTimeout = c.Server.ReadTimeout.String()
|
||||
result.Server.WriteTimeout = c.Server.WriteTimeout.String()
|
||||
result.Database.ConnMaxLifetime = c.Database.ConnMaxLifetime.String()
|
||||
return result
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,12 @@ func TestDefaultConfig(t *testing.T) {
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, "", cfg.Database.Host)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "", cfg.Database.User)
|
||||
assert.Equal(t, "", cfg.Database.Password)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
@@ -86,11 +93,76 @@ func TestConfig_Validate(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数据库路径为空无效",
|
||||
name: "SQLite模式路径为空无效",
|
||||
modify: func(c *Config) { c.Database.Path = "" },
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "driver值不合法",
|
||||
modify: func(c *Config) { c.Database.Driver = "postgres" },
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL配置有效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.Port = 3306
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MySQL模式host为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = ""
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式user为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = ""
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式dbname为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = ""
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式忽略path字段",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -100,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
|
||||
err := cfg.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -140,7 +214,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
WriteTimeout: 20 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
Port: 3306,
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 50,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
@@ -159,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
@@ -210,6 +287,9 @@ func TestConfigPriority(t *testing.T) {
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
@@ -222,13 +302,21 @@ func TestConfigPriority(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPrintSummary(t *testing.T) {
|
||||
// 测试配置摘要输出
|
||||
t.Run("打印配置摘要", func(t *testing.T) {
|
||||
t.Run("SQLite模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
// PrintSummary 只是打印,不会返回错误
|
||||
// 这里主要验证不会 panic
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary()
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
t.Run("MySQL模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Database.Driver = "mysql"
|
||||
cfg.Database.Host = "db.example.com"
|
||||
cfg.Database.Port = 3306
|
||||
cfg.Database.User = "nex"
|
||||
cfg.Database.DBName = "nex"
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,22 +6,22 @@ import (
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置
|
||||
// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name))
|
||||
type Model struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
@@ -29,8 +29,8 @@ type Model struct {
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
@@ -47,12 +47,3 @@ func (Model) TableName() string {
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
|
||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||
|
||||
// 统一模型 ID 相关方法
|
||||
ExtractUnifiedModelID(nativePath string) (string, error)
|
||||
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
|
||||
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
}
|
||||
|
||||
// AdapterRegistry 适配器注册表接口
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != "" && !strings.Contains(suffix, "/")
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
@@ -140,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -203,3 +207,82 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -48,14 +49,36 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
// docs/api_reference/anthropic defines messages and models under /v1.
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/v1/messages", conversion.InterfaceTypeChat},
|
||||
{"/v1/models", conversion.InterfaceTypeModels},
|
||||
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
|
||||
{"/messages", conversion.InterfaceTypePassthrough},
|
||||
{"/models", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
@@ -102,9 +125,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.True(t, errors.As(err, &convErr))
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "some/deep/nested/model", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/messages")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", model)
|
||||
})
|
||||
|
||||
t.Run("chat_with_max_tokens", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3-opus", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
|
||||
// max_tokens is encoded as float in JSON numbers
|
||||
maxTokens, ok := m["max_tokens"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(1024), maxTokens)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
|
||||
// other fields preserved
|
||||
_, hasContent := m["content"]
|
||||
assert.True(t, hasContent)
|
||||
assert.Equal(t, "end_turn", m["stop_reason"])
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
|
||||
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, float64(1024), m["max_tokens"])
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/v1/models/claude-3", true},
|
||||
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"models_list_trailing_slash", "/v1/models/", false},
|
||||
{"messages_path", "/v1/messages", false},
|
||||
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
decoded, err := decodeMessage(msg)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
|
||||
}
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
|
||||
}
|
||||
|
||||
// decodeMessage 解码 Anthropic 消息
|
||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var toolResults []canonical.ContentBlock
|
||||
var others []canonical.ContentBlock
|
||||
for _, b := range blocks {
|
||||
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
if len(result) == 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||
}
|
||||
return result
|
||||
return result, nil
|
||||
|
||||
case "assistant":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
block, err := decodeSingleContentBlock(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
return blocks, nil
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
name, ok := m["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
input, err := json.Marshal(m["input"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
toolUseID, ok := m["tool_use_id"].(string)
|
||||
if !ok {
|
||||
toolUseID = ""
|
||||
}
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
encoded, err := json.Marshal(cv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = encoded
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
}, nil
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
thinking, ok := m["thinking"].(string)
|
||||
if !ok {
|
||||
thinking = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
name, ok := v["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
|
||||
result = append(result, m)
|
||||
case "tool_result":
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"stop_reason": sr,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
|
||||
@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
|
||||
assert.Equal(t, true, result["stream"])
|
||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
}
|
||||
|
||||
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
|
||||
// tool 消息应被合并到相邻 user 消息
|
||||
foundToolResult := false
|
||||
for _, m := range msgs {
|
||||
msgMap := m.(map[string]any)
|
||||
msgMap, ok := m.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if msgMap["role"] == "user" {
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if ok {
|
||||
for _, c := range content {
|
||||
block := c.(map[string]any)
|
||||
block, ok := c.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if block["type"] == "tool_result" {
|
||||
foundToolResult = true
|
||||
}
|
||||
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", firstMsg["role"])
|
||||
}
|
||||
|
||||
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "assistant", result["role"])
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", block["type"])
|
||||
assert.Equal(t, "你好", block["text"])
|
||||
}
|
||||
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
data := result["data"].([]any)
|
||||
data, ok := result["data"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 1)
|
||||
|
||||
model := data[0].(map[string]any)
|
||||
model, ok := data[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-3-opus", model["id"])
|
||||
// created 应为 RFC3339 格式
|
||||
createdAt, ok := model["created_at"].(string)
|
||||
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
userMsg := msgs[0].(map[string]any)
|
||||
userMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", userMsg["role"])
|
||||
content := userMsg["content"].([]any)
|
||||
content, ok := userMsg["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, ok := result["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasReasoning := usage["reasoning_tokens"]
|
||||
assert.False(t, hasReasoning)
|
||||
}
|
||||
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", block["type"])
|
||||
assert.Equal(t, "tool_1", block["id"])
|
||||
assert.Equal(t, "search", block["name"])
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
|
||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
data := rawChunk
|
||||
if len(d.utf8Remainder) > 0 {
|
||||
data = append(d.utf8Remainder, rawChunk...)
|
||||
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
|
||||
d.utf8Remainder = nil
|
||||
}
|
||||
|
||||
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event: "):
|
||||
eventType = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
case strings.HasPrefix(line, "data: "):
|
||||
eventData = strings.TrimPrefix(line, "data: ")
|
||||
if eventType != "" && eventData != "" {
|
||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
}
|
||||
eventType = ""
|
||||
eventData = ""
|
||||
} else if line == "" {
|
||||
// SSE 事件分隔符
|
||||
case line == "":
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
|
||||
@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
checkValue string
|
||||
}{
|
||||
{
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
checkField: "text",
|
||||
checkValue: "你好",
|
||||
},
|
||||
{
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
checkField: "partial_json",
|
||||
checkValue: "{\"key\":",
|
||||
},
|
||||
{
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
checkField: "thinking",
|
||||
checkValue: "思考中",
|
||||
},
|
||||
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": tt.deltaData,
|
||||
"delta": tt.deltaData,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
|
||||
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
|
||||
"type": "content_block_start",
|
||||
"index": 3,
|
||||
"content_block": map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "search_1",
|
||||
},
|
||||
}
|
||||
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
},
|
||||
}
|
||||
deltaPayload1 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 25},
|
||||
}
|
||||
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
assert.Equal(t, 25, events[0].Usage.OutputTokens)
|
||||
|
||||
deltaPayload2 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 30},
|
||||
}
|
||||
|
||||
@@ -50,16 +50,24 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
|
||||
}
|
||||
if event.Message != nil {
|
||||
msg := map[string]any{
|
||||
"id": event.Message.ID,
|
||||
"model": event.Message.Model,
|
||||
"role": "assistant",
|
||||
"id": event.Message.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": event.Message.Model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
}
|
||||
if event.Message.Usage != nil {
|
||||
usage := map[string]any{
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": event.Message.Usage.InputTokens,
|
||||
"output_tokens": event.Message.Usage.OutputTokens,
|
||||
}
|
||||
msg["usage"] = usage
|
||||
} else {
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
payload["message"] = msg
|
||||
}
|
||||
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": event.Usage.OutputTokens,
|
||||
}
|
||||
} else {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
return e.marshalEvent("message_delta", payload)
|
||||
}
|
||||
|
||||
@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
assert.Contains(t, s, "msg_1")
|
||||
assert.Contains(t, s, "claude-3")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := payload["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "msg_1", msg["id"])
|
||||
assert.Equal(t, "message", msg["type"])
|
||||
assert.Equal(t, "assistant", msg["role"])
|
||||
assert.Equal(t, []any{}, msg["content"])
|
||||
assert.Equal(t, "claude-3", msg["model"])
|
||||
assert.Nil(t, msg["stop_reason"])
|
||||
assert.Nil(t, msg["stop_sequence"])
|
||||
|
||||
usage, okU := msg["usage"].(map[string]any)
|
||||
require.True(t, okU)
|
||||
assert.Equal(t, float64(0), usage["input_tokens"])
|
||||
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := payload["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, okU := msg["usage"].(map[string]any)
|
||||
require.True(t, okU)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(50), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
||||
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
delta, okd := payload["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
|
||||
usage, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku, "message_delta SHALL always include usage")
|
||||
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
u, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
|
||||
@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||
blocks := decodeContentBlocks(nil)
|
||||
blocks, err := decodeContentBlocks(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||
blocks := decodeContentBlocks("hello")
|
||||
blocks, err := decodeContentBlocks("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "hello", blocks[0].Text)
|
||||
}
|
||||
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeToolChoice(tt.choice)
|
||||
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
|
||||
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
|
||||
r, ok := result.(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.want["type"], r["type"])
|
||||
assert.Equal(t, tt.want["name"], r["name"])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
tools := result["tools"].([]any)
|
||||
tools, okt := result["tools"].([]any)
|
||||
require.True(t, okt)
|
||||
assert.Len(t, tools, 1)
|
||||
tool := tools[0].(map[string]any)
|
||||
tool, okt2 := tools[0].(map[string]any)
|
||||
require.True(t, okt2)
|
||||
assert.Equal(t, "search", tool["name"])
|
||||
assert.Equal(t, "Search things", tool["description"])
|
||||
tc := result["tool_choice"].(map[string]any)
|
||||
tc, oktc := result["tool_choice"].(map[string]any)
|
||||
require.True(t, oktc)
|
||||
assert.Equal(t, "auto", tc["type"])
|
||||
}
|
||||
|
||||
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
CacheCreationTokens: &cacheCreation,
|
||||
},
|
||||
}
|
||||
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||
|
||||
@@ -6,22 +6,22 @@ import (
|
||||
|
||||
// MessagesRequest Anthropic Messages 请求
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
// RequestMetadata 请求元数据
|
||||
@@ -122,8 +122,8 @@ type ContentBlock struct {
|
||||
|
||||
// ResponseUsage 响应用量
|
||||
type ResponseUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
|
||||
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
|
||||
|
||||
// EmbeddingData 嵌入数据项
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
}
|
||||
|
||||
// EmbeddingUsage 嵌入用量
|
||||
|
||||
@@ -18,17 +18,17 @@ const (
|
||||
type DeltaType string
|
||||
|
||||
const (
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
)
|
||||
|
||||
// StreamDelta 流式增量联合体
|
||||
type StreamDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// StreamContentBlock 流式内容块联合体
|
||||
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
|
||||
Message *StreamMessage `json:"message,omitempty"`
|
||||
|
||||
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
|
||||
Index *int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
|
||||
// MessageDeltaEvent
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage *CanonicalUsage `json:"usage,omitempty"`
|
||||
|
||||
// ErrorEvent
|
||||
|
||||
@@ -40,8 +40,8 @@ type ContentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// ToolUseBlock
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// ToolResultBlock
|
||||
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
|
||||
|
||||
// OutputFormat 输出格式联合体
|
||||
type OutputFormat struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalRequest 规范请求
|
||||
type CanonicalRequest struct {
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Messages []CanonicalMessage `json:"messages"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalUsage 规范用量
|
||||
type CanonicalUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalResponse 规范响应
|
||||
type CanonicalResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetSystemString 获取系统消息字符串
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
func TestGetSystemString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
}{
|
||||
{"string", "hello", "hello"},
|
||||
{"nil", nil, ""},
|
||||
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
|
||||
func TestCanonicalResponse_RoundTrip(t *testing.T) {
|
||||
sr := StopReasonEndTurn
|
||||
resp := &CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
StopReason: &sr,
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
|
||||
@@ -3,10 +3,14 @@ package conversion
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// HTTPRequestSpec HTTP 请求规格
|
||||
@@ -33,13 +37,10 @@ type ConversionEngine struct {
|
||||
|
||||
// NewConversionEngine 创建转换引擎
|
||||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||||
if logger == nil {
|
||||
logger = zap.L()
|
||||
}
|
||||
return &ConversionEngine{
|
||||
registry: registry,
|
||||
middlewareChain: NewMiddlewareChain(),
|
||||
logger: logger,
|
||||
logger: pkglogger.WithModule(logger, "conversion.engine"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,18 +73,39 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
|
||||
|
||||
// ConvertHttpRequest 转换 HTTP 请求
|
||||
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
||||
nativePath := spec.URL
|
||||
nativePath, rawQuery := splitRequestPath(spec.URL)
|
||||
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
|
||||
rewrittenBody := spec.Body
|
||||
|
||||
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
|
||||
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
|
||||
if len(spec.Body) > 0 && provider.ModelName != "" {
|
||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||
zap.Error(err),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
rewrittenBody = spec.Body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + nativePath,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: spec.Body,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -97,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
@@ -105,16 +128,34 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + providerUrl,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConvertHttpResponse 转换 HTTP 响应
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
||||
// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if rewriteErr != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.Error(rewriteErr),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
} else {
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
@@ -127,7 +168,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -139,9 +180,16 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateStreamConverter 创建流式转换器
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
|
||||
// CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
|
||||
@@ -155,9 +203,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
}
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
Timestamp: time.Now(),
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: interfaceType,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
return NewCanonicalStreamConverterWithMiddleware(
|
||||
@@ -167,6 +215,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
ctx,
|
||||
clientProtocol,
|
||||
providerProtocol,
|
||||
modelOverride,
|
||||
), nil
|
||||
}
|
||||
|
||||
@@ -192,11 +241,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponseBody 转换响应体
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeModels:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||
return body, nil
|
||||
@@ -211,12 +260,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -225,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
||||
return nil, NewRequestJSONParseError("解码请求失败", err)
|
||||
}
|
||||
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
@@ -233,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsUnsupportedMultimodal(canonicalReq) {
|
||||
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
|
||||
}
|
||||
|
||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||
if err != nil {
|
||||
@@ -241,10 +293,13 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
return nil, NewResponseJSONParseError("解码响应失败", err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||
if err != nil {
|
||||
@@ -256,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
|
||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -270,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
|
||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -284,36 +339,43 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||
if decodeErr == nil {
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// DetectInterfaceType 检测接口类型
|
||||
@@ -322,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
|
||||
if err != nil {
|
||||
return InterfaceTypePassthrough, err
|
||||
}
|
||||
nativePath, _ = splitRequestPath(nativePath)
|
||||
return adapter.DetectInterfaceType(nativePath), nil
|
||||
}
|
||||
|
||||
@@ -335,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(fallback)
|
||||
return body, 500, nil
|
||||
body, marshalErr := json.Marshal(fallback)
|
||||
if marshalErr == nil {
|
||||
return body, 500, nil
|
||||
}
|
||||
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
}
|
||||
|
||||
func splitRequestPath(rawPath string) (string, string) {
|
||||
path, query, found := strings.Cut(rawPath, "?")
|
||||
if !found {
|
||||
return rawPath, ""
|
||||
}
|
||||
return path, query
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
if strings.Contains(path, "?") {
|
||||
return path + "&" + rawQuery
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
if baseURL == "" {
|
||||
return path
|
||||
}
|
||||
if path == "" {
|
||||
return baseURL
|
||||
}
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "image", "audio", "video", "file":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
63
backend/internal/conversion/engine_adapter_test.go
Normal file
63
backend/internal/conversion/engine_adapter_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package conversion_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
adapter conversion.ProtocolAdapter
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
baseURL string
|
||||
nativePath string
|
||||
expectedURL string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "openai base url includes version path",
|
||||
adapter: openai.NewAdapter(),
|
||||
clientProtocol: "openai",
|
||||
providerProtocol: "openai",
|
||||
baseURL: "http://example.com/v1",
|
||||
nativePath: "/chat/completions",
|
||||
expectedURL: "http://example.com/v1/chat/completions",
|
||||
body: []byte(`{"model":"gpt-4","messages":[]}`),
|
||||
},
|
||||
{
|
||||
name: "anthropic native path keeps v1",
|
||||
adapter: anthropic.NewAdapter(),
|
||||
clientProtocol: "anthropic",
|
||||
providerProtocol: "anthropic",
|
||||
baseURL: "http://example.com",
|
||||
nativePath: "/v1/messages",
|
||||
expectedURL: "http://example.com/v1/messages",
|
||||
body: []byte(`{"model":"claude","messages":[]}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(tt.adapter))
|
||||
|
||||
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
|
||||
URL: tt.nativePath,
|
||||
Method: "POST",
|
||||
Body: tt.body,
|
||||
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedURL, out.URL)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversionError_WithProviderProtocol(t *testing.T) {
|
||||
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
|
||||
|
||||
func TestEngine_Use(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
called := false
|
||||
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
called = true
|
||||
@@ -58,7 +59,7 @@ func TestEngine_Use(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
return nil, errors.New("decode failed")
|
||||
@@ -75,14 +76,14 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("", "", ""))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
|
||||
@@ -91,14 +92,14 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("", "", ""))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
@@ -113,7 +114,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
|
||||
}, "client", "provider", InterfaceTypeChat)
|
||||
}, "client", "provider", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Contains(t, string(result.Body), "resp-1")
|
||||
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return nil, errors.New("decode error")
|
||||
@@ -129,13 +130,13 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeEmbeddings
|
||||
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeRerank
|
||||
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
@@ -189,14 +190,14 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeEmbeddings)
|
||||
}, "client", "provider", InterfaceTypeEmbeddings, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
@@ -207,14 +208,14 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeRerank)
|
||||
}, "client", "provider", InterfaceTypeRerank, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeModels
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -224,7 +225,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
|
||||
body := []byte(`{"object":"list","data":[]}`)
|
||||
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/models", Method: "GET", Body: body,
|
||||
URL: "/models", Method: "GET", Body: body,
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", ""))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result.Body)
|
||||
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -242,14 +243,14 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
||||
}, "client", "provider", InterfaceTypeModels)
|
||||
}, "client", "provider", InterfaceTypeModels, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -259,7 +260,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
||||
}, "client", "provider", InterfaceTypeModelInfo)
|
||||
}, "client", "provider", InterfaceTypeModelInfo, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
@@ -321,3 +322,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
|
||||
}
|
||||
|
||||
var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, errors.New("decode embedding failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"text-embedding","input":"hello"}`)
|
||||
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, errors.New("decode rerank failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
|
||||
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"test":"data"}`)
|
||||
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
@@ -13,16 +14,20 @@ import (
|
||||
|
||||
// mockProtocolAdapter 模拟协议适配器
|
||||
type mockProtocolAdapter struct {
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
|
||||
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
@@ -34,8 +39,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
|
||||
|
||||
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
|
||||
@@ -124,6 +129,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
if m.decodeEmbeddingReqFn != nil {
|
||||
return m.decodeEmbeddingReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -140,6 +148,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
if m.decodeRerankReqFn != nil {
|
||||
return m.decodeRerankReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalRerankRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -155,23 +166,47 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteReqFn != nil {
|
||||
return m.rewriteReqFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteRespFn != nil {
|
||||
return m.rewriteRespFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// noopStreamDecoder 空流式解码器
|
||||
type noopStreamDecoder struct{}
|
||||
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
|
||||
// noopStreamEncoder 空流式编码器
|
||||
type noopStreamEncoder struct{}
|
||||
|
||||
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
|
||||
// ============ 测试用例 ============
|
||||
|
||||
func TestNewConversionEngine(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine)
|
||||
assert.Equal(t, registry, engine.GetRegistry())
|
||||
}
|
||||
@@ -179,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
|
||||
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
t.Run("nil_logger_uses_global", func(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine.logger)
|
||||
})
|
||||
|
||||
@@ -187,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
customLogger := zap.NewNop()
|
||||
engine := NewConversionEngine(registry, customLogger)
|
||||
assert.Equal(t, customLogger, engine.logger)
|
||||
assert.NotNil(t, engine.logger)
|
||||
assert.Contains(t, engine.logger.Name(), "conversion.engine")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterAdapter(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
adapter := newMockAdapter("test-proto", true)
|
||||
err := engine.RegisterAdapter(adapter)
|
||||
@@ -205,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("openai", true)
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
@@ -214,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
||||
|
||||
@@ -223,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
||||
|
||||
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
||||
@@ -231,19 +267,19 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
|
||||
func TestDetectInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("test", true)
|
||||
adapter.ifaceType = InterfaceTypeChat
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test")
|
||||
ifaceType, err := engine.DetectInterfaceType("/chat/completions", "test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, InterfaceTypeChat, ifaceType)
|
||||
}
|
||||
|
||||
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -251,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
openaiAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/chat/completions"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
|
||||
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||
Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client-proto", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
@@ -299,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
assert.NotNil(t, result.Body)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
openaiAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/chat/completions"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
return []byte(`{"model":"` + newModel + `"}`), nil
|
||||
}
|
||||
require.NoError(t, registry.Register(openaiAdapter))
|
||||
|
||||
anthropicAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("anthropic", false),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/v1/messages"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
anthropicAdapter.ifaceType = InterfaceTypeChat
|
||||
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
require.NoError(t, registry.Register(anthropicAdapter))
|
||||
|
||||
t.Run("OpenAI to Anthropic", func(t *testing.T) {
|
||||
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
|
||||
})
|
||||
|
||||
t.Run("Anthropic to OpenAI", func(t *testing.T) {
|
||||
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/messages",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||
})
|
||||
}
|
||||
|
||||
type buildURLMockAdapter struct {
|
||||
*mockProtocolAdapter
|
||||
buildURLFn func(string, InterfaceType) string
|
||||
}
|
||||
|
||||
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
|
||||
if m.buildURLFn != nil {
|
||||
return m.buildURLFn(nativePath, interfaceType)
|
||||
}
|
||||
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
@@ -309,7 +430,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
Body: []byte(`{"id":"123"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
@@ -317,10 +438,10 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai")
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*PassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -328,11 +449,11 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider")
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -340,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
@@ -352,7 +473,7 @@ func TestEncodeError(t *testing.T) {
|
||||
|
||||
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
||||
@@ -380,3 +501,233 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "未找到适配器")
|
||||
}
|
||||
|
||||
// ============ modelOverride 测试 ============
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
return json.Marshal(map[string]any{"model": resp.Model})
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"model":"native-model"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||
assert.Equal(t, "provider/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
assert.Equal(t, "resp-1", resp["id"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证 SSE frame 中的 data JSON 被改写
|
||||
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]interface{}
|
||||
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
|
||||
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// provider adapter 解码出含 model 的流式事件
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
|
||||
canonical.NewMessageStopEvent(),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
// client adapter 编码时输出 model 字段
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"type": string(event.Type),
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
|
||||
return [][]byte{data}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证类型是 CanonicalStreamConverter
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 处理一个 chunk,验证 model 被覆写为统一模型 ID
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
|
||||
|
||||
var startEvent map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
|
||||
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
// modelOverride 为空,不应覆写
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
|
||||
}
|
||||
|
||||
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test)
|
||||
type engineTestStreamDecoder struct {
|
||||
processFn func([]byte) []canonical.CanonicalStreamEvent
|
||||
flushFn func() []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
if d.processFn != nil {
|
||||
return d.processFn(raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
if d.flushFn != nil {
|
||||
return d.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test)
|
||||
type engineTestStreamEncoder struct {
|
||||
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
|
||||
flushFn func() [][]byte
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if e.encodeFn != nil {
|
||||
return e.encodeFn(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||
if e.flushFn != nil {
|
||||
return e.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,17 +6,24 @@ import "fmt"
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrorDetailPhase = "phase"
|
||||
ErrorPhaseRequest = "request"
|
||||
ErrorPhaseResponse = "response"
|
||||
)
|
||||
|
||||
// ConversionError 协议转换错误
|
||||
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
|
||||
}
|
||||
}
|
||||
|
||||
// NewRequestJSONParseError 创建请求 JSON 解析错误。
|
||||
func NewRequestJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// NewResponseJSONParseError 创建响应 JSON 解析错误。
|
||||
func NewResponseJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// WithClientProtocol 设置客户端协议
|
||||
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
|
||||
e.ClientProtocol = protocol
|
||||
|
||||
@@ -4,10 +4,10 @@ package conversion
|
||||
type InterfaceType string
|
||||
|
||||
const (
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -43,26 +44,31 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != "" && !strings.Contains(suffix, "/")
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/chat/completions"
|
||||
return "/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
return "/models"
|
||||
case conversion.InterfaceTypeModelInfo:
|
||||
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
|
||||
return "/models/" + modelID
|
||||
}
|
||||
return nativePath
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/v1/embeddings"
|
||||
return "/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
return "/v1/rerank"
|
||||
return "/rerank"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
@@ -137,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Code: string(err.Code),
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -216,3 +225,92 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return encodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,10 +30,10 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}{
|
||||
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
|
||||
{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
|
||||
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -44,20 +44,43 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/chat/completions", conversion.InterfaceTypePassthrough},
|
||||
{"/models", conversion.InterfaceTypePassthrough},
|
||||
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
|
||||
{"/embeddings", conversion.InterfaceTypePassthrough},
|
||||
{"/rerank", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
|
||||
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
|
||||
{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
|
||||
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
|
||||
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
|
||||
{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
|
||||
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
||||
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
||||
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -92,9 +115,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
@@ -118,13 +141,13 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"model_info", "/v1/models/gpt-4", true},
|
||||
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
|
||||
{"model_info", "/v1/models/openai/gpt-4", true},
|
||||
{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"nested_path", "/v1/models/gpt-4/versions", false},
|
||||
{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
|
||||
{"empty_suffix", "/v1/models/", false},
|
||||
{"unrelated", "/v1/chat/completions", false},
|
||||
{"partial_prefix", "/v1/model", false},
|
||||
{"partial_prefix", "/model", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("标准路径", func(t *testing.T) {
|
||||
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", modelID)
|
||||
})
|
||||
|
||||
t.Run("复杂路径", func(t *testing.T) {
|
||||
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
|
||||
})
|
||||
|
||||
t.Run("非模型详情路径报错", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
||||
|
||||
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", model)
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", model)
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
|
||||
// messages field preserved
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "text-embedding", m["model"])
|
||||
assert.Equal(t, "hello", m["input"])
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "rerank", m["model"])
|
||||
assert.Equal(t, "test", m["query"])
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("rerank_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"rerank","results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/rerank", m["model"])
|
||||
})
|
||||
|
||||
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
|
||||
body := []byte(`{"results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
_, hasModel := m["model"]
|
||||
assert.False(t, hasModel, "rerank response without model field should not have one added")
|
||||
})
|
||||
|
||||
t.Run("embedding_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"text-embedding","data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
|
||||
body := []byte(`{"data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding", afterRewrite)
|
||||
})
|
||||
|
||||
t.Run("rerank_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rerank", afterRewrite)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/v1/models/gpt-4", true},
|
||||
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"models_list_trailing_slash", "/v1/models/", false},
|
||||
{"chat_completions", "/v1/chat/completions", false},
|
||||
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
|
||||
// contentPart 内容部分
|
||||
type contentPart struct {
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
}
|
||||
|
||||
// decodeContentParts 解码内容部分
|
||||
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
|
||||
var result []contentPart
|
||||
for _, item := range parts {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "text", Text: text})
|
||||
case "refusal":
|
||||
refusal, _ := m["refusal"].(string)
|
||||
refusal, ok := m["refusal"].(string)
|
||||
if !ok {
|
||||
refusal = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||
}
|
||||
}
|
||||
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "function":
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
name, _ := fn["name"].(string)
|
||||
name, ok := fn["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "custom":
|
||||
if custom, ok := v["custom"].(map[string]any); ok {
|
||||
name, _ := custom["name"].(string)
|
||||
name, ok := custom["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "allowed_tools":
|
||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||
mode, _ := at["mode"].(string)
|
||||
mode, ok := at["mode"].(string)
|
||||
if !ok {
|
||||
mode = ""
|
||||
}
|
||||
if mode == "required" {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
|
||||
case map[string]any:
|
||||
if name, ok := v["name"].(string); ok {
|
||||
req.ToolChoice = map[string]any{
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": name},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello back"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assistantMsg := req.Messages[1]
|
||||
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
|
||||
assert.Len(t, assistantMsg.Content, 1)
|
||||
assert.Equal(t, "text", assistantMsg.Content[0].Type)
|
||||
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 2)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "system", firstMsg["role"])
|
||||
assert.Equal(t, "你是助手", firstMsg["content"])
|
||||
}
|
||||
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assistantMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assistantMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, toolCalls, 1)
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
tc, ok := toolCalls[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_1", tc["id"])
|
||||
}
|
||||
|
||||
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
|
||||
func TestEncodeResponse_Basic(t *testing.T) {
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "resp-1", result["id"])
|
||||
assert.Equal(t, "chat.completion", result["object"])
|
||||
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
choices, ok := result["choices"].([]any)
|
||||
require.True(t, ok)
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
msg, ok := choice["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "你好", msg["content"])
|
||||
assert.Equal(t, "stop", choice["finish_reason"])
|
||||
}
|
||||
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
sr := canonical.StopReasonToolUse
|
||||
input := json.RawMessage(`{"q":"test"}`)
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okc := result["choices"].([]any)
|
||||
require.True(t, okc)
|
||||
msgMap, okm := choices[0].(map[string]any)
|
||||
require.True(t, okm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
tcs, ok := msg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, tcs, 1)
|
||||
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "list", result["object"])
|
||||
data := result["data"].([]any)
|
||||
data, okd := result["data"].([]any)
|
||||
require.True(t, okd)
|
||||
assert.Len(t, data, 2)
|
||||
}
|
||||
|
||||
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
assert.Equal(t, "回答", msg["content"])
|
||||
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
|
||||
|
||||
// 先发送一个文本 chunk
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx0 := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx1 := 1
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
// StreamEncoder OpenAI 流式编码器
|
||||
type StreamEncoder struct {
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
|
||||
@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
choices, okch := payload["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
delta, okd := msgMap["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
|
||||
@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "rerank-1", result["model"])
|
||||
results := result["results"].([]any)
|
||||
results, okr := result["results"].([]any)
|
||||
require.True(t, okr)
|
||||
assert.Len(t, results, 1)
|
||||
}
|
||||
|
||||
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
reasoning := 20
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
choice, okc := choices[0].(map[string]any)
|
||||
require.True(t, okc)
|
||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,42 +4,42 @@ import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completion 请求
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// 已废弃字段
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// Message OpenAI 消息
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
// 已废弃
|
||||
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
|
||||
@@ -88,8 +88,8 @@ type FunctionDef struct {
|
||||
|
||||
// ResponseFormat OpenAI 响应格式
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// JSONSchemaDef JSON Schema 定义
|
||||
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
|
||||
|
||||
// Choice OpenAI 选择项
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
@@ -127,10 +127,10 @@ type Choice struct {
|
||||
|
||||
// Usage OpenAI 用量
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package conversion
|
||||
|
||||
import "nex/backend/internal/conversion/canonical"
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder 流式解码器接口
|
||||
type StreamDecoder interface {
|
||||
@@ -38,14 +43,74 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
||||
return &SmartPassthroughStreamConverter{
|
||||
adapter: adapter,
|
||||
modelOverride: modelOverride,
|
||||
interfaceType: interfaceType,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.buffer = append(c.buffer, rawChunk...)
|
||||
frames, rest := splitSSEFrames(c.buffer)
|
||||
c.buffer = rest
|
||||
|
||||
result := make([][]byte, 0, len(frames))
|
||||
for _, frame := range frames {
|
||||
result = append(result, c.rewriteFrame(frame))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
||||
return frame
|
||||
}
|
||||
|
||||
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
return frame
|
||||
}
|
||||
|
||||
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
||||
}
|
||||
|
||||
// Flush 输出未形成完整 frame 的剩余数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
if len(c.buffer) == 0 {
|
||||
return nil
|
||||
}
|
||||
frame := append([]byte(nil), c.buffer...)
|
||||
c.buffer = nil
|
||||
return [][]byte{c.rewriteFrame(frame)}
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
type CanonicalStreamConverter struct {
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
modelOverride string
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||
@@ -57,18 +122,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
providerProtocol: providerProtocol,
|
||||
modelOverride: modelOverride,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 解码 → 中间件 → 编码管道
|
||||
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
events := c.decoder.ProcessChunk(rawChunk)
|
||||
var result [][]byte
|
||||
@@ -80,6 +146,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -98,6 +165,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -105,3 +173,93 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
result = append(result, encoderChunks...)
|
||||
return result
|
||||
}
|
||||
|
||||
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
||||
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
||||
if c.modelOverride != "" && event.Message != nil {
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
func splitSSEFrames(data []byte) ([][]byte, []byte) {
|
||||
var frames [][]byte
|
||||
for len(data) > 0 {
|
||||
idx, sepLen := findSSEFrameSeparator(data)
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
end := idx + sepLen
|
||||
frames = append(frames, append([]byte(nil), data[:end]...))
|
||||
data = data[end:]
|
||||
}
|
||||
return frames, data
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
|
||||
lineEnding, separator := sseLineEnding(frame)
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
out := make([]string, 0, len(lines)+1)
|
||||
dataWritten := false
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
if !dataWritten {
|
||||
for _, dataLine := range strings.Split(data, "\n") {
|
||||
out = append(out, "data: "+dataLine)
|
||||
}
|
||||
dataWritten = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out = append(out, line)
|
||||
}
|
||||
if !dataWritten {
|
||||
out = append(out, "data: "+data)
|
||||
}
|
||||
return []byte(strings.Join(out, lineEnding) + separator)
|
||||
}
|
||||
|
||||
func sseLineEnding(frame []byte) (string, string) {
|
||||
if bytes.Contains(frame, []byte("\r\n")) {
|
||||
return "\r\n", "\r\n\r\n"
|
||||
}
|
||||
return "\n", "\n\n"
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
||||
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
|
||||
132
backend/internal/database/database.go
Normal file
132
backend/internal/database/database.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/migrations"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
db, err := initDB(cfg, moduleLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
configurePool(db, cfg, moduleLogger)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Close(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
gormLogger := pkglogger.NewGormLogger(zapLogger)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
}
|
||||
|
||||
switch cfg.Driver {
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 MySQL 数据库",
|
||||
zap.String("host", cfg.Host),
|
||||
zap.Int("port", cfg.Port),
|
||||
zap.String("database", cfg.DBName))
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
default:
|
||||
dbDir := filepath.Dir(cfg.Path)
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
|
||||
}
|
||||
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dialect, fsys, err := migrations.ForDriver(driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("执行数据库迁移",
|
||||
zap.String("dialect", string(dialect)),
|
||||
zap.String("driver", driver))
|
||||
}
|
||||
|
||||
provider, err := goose.NewProvider(dialect, sqlDB, fsys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建迁移提供者失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := provider.Up(context.Background()); err != nil {
|
||||
return fmt.Errorf("执行迁移失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
|
||||
if cfg.Driver == "sqlite" {
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("数据库连接池配置",
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
|
||||
}
|
||||
}
|
||||
|
||||
func BuildDSN(cfg *config.DatabaseConfig) string {
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
}
|
||||
165
backend/internal/database/database_test.go
Normal file
165
backend/internal/database/database_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/migrations"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestInit_SQLite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
defer Close(db)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sqlDB)
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
Close(db)
|
||||
}
|
||||
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "db.example.com",
|
||||
Port: 3306,
|
||||
User: "nexuser",
|
||||
Password: "secretpass",
|
||||
DBName: "nexdb",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
|
||||
func TestBuildDSN_EmptyPassword(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
DBName: "nex",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
|
||||
func TestInit_SQLite_AnyCWD(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
origDir, err := os.Getwd()
|
||||
if err == nil {
|
||||
defer func() {
|
||||
if chdirErr := os.Chdir(origDir); chdirErr != nil {
|
||||
t.Logf("无法恢复工作目录: %v", chdirErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if chdirErr := os.Chdir(dir); chdirErr != nil {
|
||||
t.Skipf("无法切换到临时目录: %v", chdirErr)
|
||||
}
|
||||
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
defer Close(db)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sqlDB)
|
||||
}
|
||||
|
||||
func TestForDriverDialect_SQLite(t *testing.T) {
|
||||
require.NoError(t, testMigrateWithDriver(t, "sqlite"))
|
||||
}
|
||||
|
||||
func TestForDriverDialect_MySQL(t *testing.T) {
|
||||
dialect, fsys, err := migrations.ForDriver("mysql")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "mysql", string(dialect))
|
||||
entries, fsErr := fs.ReadDir(fsys, ".")
|
||||
require.NoError(t, fsErr)
|
||||
assert.NotEmpty(t, entries, "MySQL 迁移资源应至少包含一个文件")
|
||||
}
|
||||
|
||||
func TestForDriverDialect_Invalid(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "postgres",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
_, err := Init(cfg, zapLogger)
|
||||
assert.Error(t, err, "非法 driver 应返回错误")
|
||||
assert.Contains(t, err.Error(), "不支持的数据库驱动")
|
||||
}
|
||||
|
||||
func testMigrateWithDriver(t *testing.T, driver string) error {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: driver,
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
Close(db)
|
||||
return nil
|
||||
}
|
||||
71
backend/internal/database/embedded_migration_test.go
Normal file
71
backend/internal/database/embedded_migration_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
|
||||
"nex/backend/migrations"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddedMigrations_SQLiteResourcesPresent(t *testing.T) {
|
||||
entries, err := fs.ReadDir(migrations.FS, "sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
var sqlFiles []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
sqlFiles = append(sqlFiles, entry.Name())
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, sqlFiles, "SQLite 迁移资源应至少包含一个 .sql 文件")
|
||||
}
|
||||
|
||||
func TestEmbeddedMigrations_MySQLResourcesPresent(t *testing.T) {
|
||||
entries, err := fs.ReadDir(migrations.FS, "mysql")
|
||||
require.NoError(t, err)
|
||||
|
||||
var sqlFiles []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
sqlFiles = append(sqlFiles, entry.Name())
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, sqlFiles, "MySQL 迁移资源应至少包含一个 .sql 文件")
|
||||
}
|
||||
|
||||
func TestEmbeddedMigrations_SQLiteSQLParsable(t *testing.T) {
|
||||
subFS, err := fs.Sub(migrations.FS, "sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := fs.ReadDir(subFS, ".")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
data, err := fs.ReadFile(subFS, entry.Name())
|
||||
require.NoError(t, err, "无法读取迁移文件: %s", entry.Name())
|
||||
assert.NotEmpty(t, data, "迁移文件内容不应为空: %s", entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedMigrations_MySQLSQLParsable(t *testing.T) {
|
||||
subFS, err := fs.Sub(migrations.FS, "mysql")
|
||||
require.NoError(t, err)
|
||||
|
||||
entries, err := fs.ReadDir(subFS, ".")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
data, err := fs.ReadFile(subFS, entry.Name())
|
||||
require.NoError(t, err, "无法读取迁移文件: %s", entry.Name())
|
||||
assert.NotEmpty(t, data, "迁移文件内容不应为空: %s", entry.Name())
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
// Model 模型领域模型
|
||||
"nex/backend/pkg/modelid"
|
||||
)
|
||||
|
||||
// Model 模型领域模型(id 为 UUID 自动生成)
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
@@ -10,3 +14,8 @@ type Model struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// UnifiedModelID 返回统一模型 ID(格式:provider_id/model_name)
|
||||
func (m *Model) UnifiedModelID() string {
|
||||
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,3 @@ type Provider struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,20 +6,27 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -33,11 +40,16 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
var result domain.Provider
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "p1", result.ID)
|
||||
assert.Contains(t, result.APIKey, "***")
|
||||
assert.Equal(t, "sk-test", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -110,10 +140,17 @@ func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "m1",
|
||||
"provider_id": "p1",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
@@ -127,13 +164,16 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
|
||||
var result domain.Model
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -149,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -2,21 +2,22 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -24,82 +25,12 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -112,12 +43,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -127,14 +61,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -145,10 +82,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -161,12 +100,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ModelName: "gpt-4"},
|
||||
{ID: "m2", ModelName: "gpt-3.5"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -174,16 +116,98 @@ func TestModelHandler_ListModels(t *testing.T) {
|
||||
|
||||
h.ListModels(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
require.Len(t, result, 2)
|
||||
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
|
||||
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
|
||||
|
||||
h.GetModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "mock-uuid-1234", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
}, nil)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -194,7 +218,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -205,14 +233,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
}, nil)
|
||||
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
})
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -222,8 +253,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -252,12 +281,13 @@ func formatMapErrors(errs map[string]string) string {
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// ============ 错误类型判断测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
err: gorm.ErrDuplicatedKey,
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -273,3 +303,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "nonexistent",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商不存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
var requestIDStr string
|
||||
if id, ok := requestID.(string); ok {
|
||||
requestIDStr = id
|
||||
}
|
||||
logger.Debug("请求开始",
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Query(query),
|
||||
pkglogger.ClientIP(c.ClientIP()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
@@ -28,13 +33,13 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
logger.Debug("请求结束",
|
||||
pkglogger.StatusCode(statusCode),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Latency(latency),
|
||||
pkglogger.BodySize(c.Writer.Size()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
|
||||
core, logs := observer.New(zapcore.InfoLevel)
|
||||
logger := zap.New(core)
|
||||
|
||||
w := serveLoggingRequest(logger)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Empty(t, logs.FilterMessage("请求开始").All())
|
||||
assert.Empty(t, logs.FilterMessage("请求结束").All())
|
||||
}
|
||||
|
||||
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
|
||||
core, logs := observer.New(zapcore.DebugLevel)
|
||||
logger := zap.New(core)
|
||||
|
||||
w := serveLoggingRequest(logger)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
startLogs := logs.FilterMessage("请求开始").All()
|
||||
endLogs := logs.FilterMessage("请求结束").All()
|
||||
if assert.Len(t, startLogs, 1) {
|
||||
fields := startLogs[0].ContextMap()
|
||||
assert.Equal(t, "GET", fields["method"])
|
||||
assert.Equal(t, "/test", fields["path"])
|
||||
assert.Equal(t, "key=value", fields["query"])
|
||||
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||
assert.NotEmpty(t, fields["client_ip"])
|
||||
}
|
||||
if assert.Len(t, endLogs, 1) {
|
||||
fields := endLogs[0].ContextMap()
|
||||
assert.Equal(t, int64(200), fields["status"])
|
||||
assert.Equal(t, "GET", fields["method"])
|
||||
assert.Equal(t, "/test", fields["path"])
|
||||
assert.Equal(t, int64(2), fields["body_size"])
|
||||
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||
assert.Contains(t, fields, "latency")
|
||||
}
|
||||
}
|
||||
|
||||
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.Use(Logging(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test?key=value", nil)
|
||||
req.Header.Set("X-Request-ID", "existing-id-123")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func TestRecovery_NoPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
@@ -22,40 +23,59 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{modelService: modelService}
|
||||
}
|
||||
|
||||
// modelResponse 模型响应 DTO,扩展 unified_id 字段
|
||||
type modelResponse struct {
|
||||
domain.Model
|
||||
UnifiedModelID string `json:"unified_id"`
|
||||
}
|
||||
|
||||
// newModelResponse 从 domain.Model 构造响应 DTO
|
||||
func newModelResponse(m *domain.Model) modelResponse {
|
||||
return modelResponse{
|
||||
Model: *m,
|
||||
UnifiedModelID: m.UnifiedModelID(),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
ProviderID string `json:"provider_id" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, provider_id, model_name",
|
||||
"error": "缺少必需字段: provider_id, model_name",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
model := &domain.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
}
|
||||
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model)
|
||||
c.JSON(http.StatusCreated, newModelResponse(model))
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
resp := make([]modelResponse, len(models))
|
||||
for i, m := range models {
|
||||
resp[i] = newModelResponse(&m)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
@@ -77,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, appErrors.ErrModelNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": appErrors.ErrDuplicateModel.Message,
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
@@ -135,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrInvalidProviderID.Message,
|
||||
"code": appErrors.ErrInvalidProviderID.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -65,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -113,17 +113,24 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrImmutableField) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrImmutableField.Message,
|
||||
"code": appErrors.ErrImmutableField.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
writeError(c, err)
|
||||
return
|
||||
@@ -138,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
|
||||
@@ -3,38 +3,44 @@ package handler
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ProxyHandler 统一代理处理器
|
||||
type ProxyHandler struct {
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
providerService service.ProviderService
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
engine: engine,
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
providerService: providerService,
|
||||
statsService: statsService,
|
||||
logger: zap.L(),
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,47 +49,93 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||||
clientProtocol := c.Param("protocol")
|
||||
if clientProtocol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
|
||||
return
|
||||
}
|
||||
|
||||
// 原始路径: /v1/{path}
|
||||
// 原始路径: /{path}
|
||||
path := c.Param("path")
|
||||
if strings.HasPrefix(path, "/") {
|
||||
path = path[1:]
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
nativePath := path
|
||||
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
|
||||
|
||||
// 获取 client adapter
|
||||
registry := h.engine.GetRegistry()
|
||||
clientAdapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 检测接口类型
|
||||
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
|
||||
// 处理 Models 接口:本地聚合
|
||||
if ifaceType == conversion.InterfaceTypeModels {
|
||||
h.handleModelsList(c, clientAdapter)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 ModelInfo 接口:本地查询
|
||||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||
if err != nil {
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
|
||||
return
|
||||
}
|
||||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||
return
|
||||
}
|
||||
nativePath := "/v1/" + path
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 model 名称(从 JSON body 中提取,GET 请求无 body)
|
||||
modelName := ""
|
||||
if len(body) > 0 {
|
||||
modelName = extractModelName(body)
|
||||
}
|
||||
|
||||
// 构建输入 HTTPRequestSpec
|
||||
inSpec := conversion.HTTPRequestSpec{
|
||||
URL: nativePath,
|
||||
URL: requestPath,
|
||||
Method: c.Request.Method,
|
||||
Headers: extractHeaders(c),
|
||||
Body: body,
|
||||
}
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.Route(modelName)
|
||||
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
|
||||
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err != nil {
|
||||
// GET 请求或无法提取 model 时,直接转发到上游
|
||||
if len(body) == 0 || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol)
|
||||
if isInvalidJSONError(err) {
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
|
||||
return
|
||||
}
|
||||
h.writeError(c, err, clientProtocol)
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
if providerID == "" || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
h.writeRouteError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -94,28 +146,53 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 构建 TargetProvider
|
||||
// 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体
|
||||
// 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名)
|
||||
// 跨协议:全量转换时 ModelName 会被编码到请求体中
|
||||
targetProvider := conversion.NewTargetProvider(
|
||||
routeResult.Provider.BaseURL,
|
||||
routeResult.Provider.APIKey,
|
||||
routeResult.Model.ModelName,
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
// 计算统一模型 ID(用于响应覆写)
|
||||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||
|
||||
if isStream {
|
||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||
} else {
|
||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isInvalidJSONError(err error) bool {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -123,37 +200,32 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.logger.Error("发送请求失败", zap.Error(err))
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应
|
||||
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
|
||||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换响应失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
@@ -161,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
|
||||
// 发送流式请求
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
@@ -180,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if _, err := writer.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
@@ -224,34 +328,166 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
|
||||
return req.Stream
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
c.Data(statusCode, "application/json", body)
|
||||
// handleModelsList 处理 GET /v1/models 本地聚合
|
||||
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
|
||||
// 从数据库查询所有启用的模型
|
||||
models, err := h.providerService.ListEnabledModels()
|
||||
if err != nil {
|
||||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
// 构建 CanonicalModelList
|
||||
modelList := &canonical.CanonicalModelList{
|
||||
Models: make([]canonical.CanonicalModel, 0, len(models)),
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
|
||||
ID: m.UnifiedModelID(),
|
||||
Name: m.ModelName,
|
||||
Created: m.CreatedAt.Unix(),
|
||||
OwnedBy: m.ProviderID,
|
||||
})
|
||||
}
|
||||
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelsResponse(modelList)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// writeError 写入路由错误
|
||||
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
|
||||
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
|
||||
// 解析统一模型 ID
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库查询模型
|
||||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||
if err != nil {
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建 CanonicalModelInfo
|
||||
modelInfo := &canonical.CanonicalModelInfo{
|
||||
ID: model.UnifiedModelID(),
|
||||
Name: model.ModelName,
|
||||
Created: model.CreatedAt.Unix(),
|
||||
OwnedBy: model.ProviderID,
|
||||
}
|
||||
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// writeConversionError 写入网关层转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
statusCode, code, message := mapConversionError(convErr)
|
||||
h.writeProxyError(c, statusCode, code, message)
|
||||
return
|
||||
}
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||||
}
|
||||
|
||||
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
|
||||
switch err.Code {
|
||||
case conversion.ErrorCodeJSONParseError:
|
||||
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
|
||||
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
|
||||
}
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
case conversion.ErrorCodeInvalidInput,
|
||||
conversion.ErrorCodeMissingRequiredField,
|
||||
conversion.ErrorCodeProtocolConstraint:
|
||||
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
|
||||
case conversion.ErrorCodeInterfaceNotSupported:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
|
||||
case conversion.ErrorCodeUnsupportedMultimodal:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
|
||||
default:
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
switch appErr.Code {
|
||||
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
|
||||
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
|
||||
default:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
|
||||
h.logger.Error("上游不可达", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": message,
|
||||
"code": code,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range resp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
|
||||
registry := h.engine.GetRegistry()
|
||||
adapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil || len(providers) == 0 {
|
||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
||||
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
|
||||
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -261,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
||||
|
||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||
|
||||
var outSpec *conversion.HTTPRequestSpec
|
||||
if clientProtocol == providerProtocol {
|
||||
upstreamURL := p.BaseURL + inSpec.URL
|
||||
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
|
||||
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
|
||||
headers := adapter.BuildHeaders(targetProvider)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
outSpec = &conversion.HTTPRequestSpec{
|
||||
URL: upstreamURL,
|
||||
URL: joinBaseURL(p.BaseURL, upstreamPath),
|
||||
Method: inSpec.Method,
|
||||
Headers: headers,
|
||||
Body: inSpec.Body,
|
||||
@@ -286,36 +521,132 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
}
|
||||
|
||||
// extractModelName 从 JSON body 中提取 model
|
||||
func extractModelName(body []byte) string {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
|
||||
if err != nil {
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("透传流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripRawQuery(path string) string {
|
||||
pathOnly, _, _ := strings.Cut(path, "?")
|
||||
return pathOnly
|
||||
}
|
||||
|
||||
func rawQueryFromPath(path string) string {
|
||||
_, rawQuery, found := strings.Cut(path, "?")
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
return req.Model
|
||||
return rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func headerValue(headers map[string]string, key string) string {
|
||||
for k, v := range headers {
|
||||
if strings.EqualFold(k, key) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func filterHopByHopHeaders(headers map[string]string) map[string]string {
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
hopByHop := map[string]struct{}{
|
||||
"connection": {},
|
||||
"transfer-encoding": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
filtered := make(map[string]string, len(headers))
|
||||
for k, v := range headers {
|
||||
if _, skip := hopByHop[strings.ToLower(k)]; skip {
|
||||
continue
|
||||
}
|
||||
filtered[k] = v
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user