Compare commits
75 Commits
b3258e76df
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 2dec9e5c54 | |||
| c524e8f928 | |||
| 6b00045f4e | |||
| e719d3c8f1 | |||
| 6908b9653b | |||
| d8e64ef0e9 | |||
| fb9f6d1d00 | |||
| e4c96da8a9 | |||
| 1195e119c6 | |||
| 4eeb14e844 | |||
| 0d30ed9a0f | |||
| cd0b3e8fc1 | |||
| c04a13bf8a | |||
| 5513f0c13d | |||
| 598e2acb7e | |||
| 4870d29638 | |||
| 8600a39b6c | |||
| 407d008e19 | |||
| a2751eab31 | |||
| 5655fc5560 | |||
| 49b47a1ae0 | |||
| bcf82d42bc | |||
| 394025c8ea | |||
| 34bd749741 | |||
| 290f299e22 | |||
| 859dec8ada | |||
| 993c0a72d6 | |||
| c9c3a84b33 | |||
| 6de7a2d2e1 | |||
| 6181923d8d | |||
| 235efb0e62 | |||
| 6b1af27ea2 | |||
| 32f48777f3 | |||
| bc7a7c6e81 | |||
| 3cd0458c2c | |||
| 8eea30ea11 | |||
| 9e33e570af | |||
| 7653385838 | |||
| 2c401f7ae6 | |||
| a9972360c2 | |||
| b00fa4dcee | |||
| 92525b39c3 | |||
| 38a2555c7b | |||
| 9622d44aac | |||
| 155244433f | |||
| 2c043c6cf7 | |||
| f5c82b6980 | |||
| 9105a36097 | |||
| f1ee646ca4 | |||
| b9b487c591 | |||
| 4c62c071fb | |||
| b2e9dd8b7f | |||
| d143c5f3df | |||
| 4eebdfb8db | |||
| b517946585 | |||
| 4ddae6be74 | |||
| 195762ff97 | |||
| bcf5ca89e5 | |||
| 365943e4c4 | |||
| 4c6b49099d | |||
| 4c78ab6cc8 | |||
| 52007c9461 | |||
| 086dd1fed7 | |||
| 1d7e839b49 | |||
| fa7babf13b | |||
| 280099b89c | |||
| 0a92a25451 | |||
| 8c075194e5 | |||
| 53e477d383 | |||
| 1522c87c74 | |||
| e0d05c9869 | |||
| 5b401e29cb | |||
| 65ac9f740a | |||
| 58ebcaa299 | |||
| 5b765c8b5e |
8
.claude/settings.json
Normal file
8
.claude/settings.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"tdesign-mcp-server": {
|
||||||
|
"command": "bunx",
|
||||||
|
"args": ["tdesign-mcp-server@latest"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
7
.editorconfig
Normal file
7
.editorconfig
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
9
.gitattributes
vendored
Normal file
9
.gitattributes
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
* text=auto eol=lf
|
||||||
|
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
assets/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||||
|
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||||
|
assets/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||||
|
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||||
|
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
14
.github/workflows/ci.yml
vendored
Normal file
14
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [dev, master]
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check:
|
||||||
|
name: Check
|
||||||
|
uses: ./.github/workflows/test.yml
|
||||||
313
.github/workflows/release.yml
vendored
Normal file
313
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*.*.*'
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
prepare:
|
||||||
|
name: Prepare Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.version.outputs.version }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Verify tag and VERSION
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
version=$(go run ./versionctl print)
|
||||||
|
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
|
||||||
|
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
test-gate:
|
||||||
|
name: Test Gate
|
||||||
|
needs: prepare
|
||||||
|
uses: ./.github/workflows/test.yml
|
||||||
|
with:
|
||||||
|
full: true
|
||||||
|
|
||||||
|
build-web:
|
||||||
|
name: Build Web Asset
|
||||||
|
needs: [prepare, test-gate]
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Preflight web release toolchain
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
command -v go
|
||||||
|
go version
|
||||||
|
command -v bun
|
||||||
|
bun --version
|
||||||
|
make release-assets-check
|
||||||
|
|
||||||
|
- name: Build web release asset
|
||||||
|
run: make release-assets-web
|
||||||
|
|
||||||
|
- name: Upload web release asset
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: release-web
|
||||||
|
path: build/release/*
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
|
build-linux:
|
||||||
|
name: Build Linux ${{ matrix.arch }} Assets
|
||||||
|
needs: [prepare, test-gate]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- arch: amd64
|
||||||
|
runner: ubuntu-latest
|
||||||
|
- arch: arm64
|
||||||
|
runner: ubuntu-24.04-arm
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Install Linux desktop and package dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y curl file libayatana-appindicator3-dev libgtk-3-dev rpm
|
||||||
|
|
||||||
|
- name: Preflight Linux release toolchain
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
printf 'runner arch: %s\n' "$(uname -m)"
|
||||||
|
command -v go
|
||||||
|
go version
|
||||||
|
command -v bun
|
||||||
|
bun --version
|
||||||
|
command -v gcc
|
||||||
|
gcc --version
|
||||||
|
command -v pkg-config
|
||||||
|
pkg-config --modversion ayatana-appindicator3-0.1
|
||||||
|
pkg-config --modversion gtk+-3.0
|
||||||
|
command -v curl
|
||||||
|
command -v dpkg-deb
|
||||||
|
dpkg-deb --version
|
||||||
|
command -v rpmbuild
|
||||||
|
rpmbuild --version
|
||||||
|
make release-assets-check
|
||||||
|
|
||||||
|
- name: Build Linux release assets
|
||||||
|
run: make release-assets-linux
|
||||||
|
|
||||||
|
- name: Upload Linux release assets
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: release-linux-${{ matrix.arch }}
|
||||||
|
path: build/release/*
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
|
build-windows:
|
||||||
|
name: Build Windows ${{ matrix.arch }} Assets
|
||||||
|
needs: [prepare, test-gate]
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- arch: amd64
|
||||||
|
runner: windows-latest
|
||||||
|
msystem: MINGW64
|
||||||
|
cc: gcc
|
||||||
|
cxx: g++
|
||||||
|
packages: >-
|
||||||
|
make
|
||||||
|
mingw-w64-x86_64-gcc
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Setup MSYS2 toolchain
|
||||||
|
uses: msys2/setup-msys2@v2
|
||||||
|
with:
|
||||||
|
msystem: ${{ matrix.msystem }}
|
||||||
|
path-type: inherit
|
||||||
|
update: true
|
||||||
|
install: ${{ matrix.packages }}
|
||||||
|
|
||||||
|
- name: Preflight Windows release toolchain
|
||||||
|
shell: msys2 {0}
|
||||||
|
env:
|
||||||
|
CC: ${{ matrix.cc }}
|
||||||
|
CXX: ${{ matrix.cxx }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
command -v go
|
||||||
|
go version
|
||||||
|
command -v bun
|
||||||
|
bun --version
|
||||||
|
command -v make
|
||||||
|
make --version
|
||||||
|
command -v "$CC"
|
||||||
|
"$CC" --version
|
||||||
|
command -v "$CXX"
|
||||||
|
"$CXX" --version
|
||||||
|
command -v windres
|
||||||
|
windres --version
|
||||||
|
if command -v powershell.exe >/dev/null 2>&1; then
|
||||||
|
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||||
|
else
|
||||||
|
command -v powershell
|
||||||
|
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||||
|
fi
|
||||||
|
make release-assets-check
|
||||||
|
|
||||||
|
- name: Build Windows release assets
|
||||||
|
shell: msys2 {0}
|
||||||
|
env:
|
||||||
|
CC: ${{ matrix.cc }}
|
||||||
|
CXX: ${{ matrix.cxx }}
|
||||||
|
run: make release-assets-windows
|
||||||
|
|
||||||
|
- name: Upload Windows release assets
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: release-windows-${{ matrix.arch }}
|
||||||
|
path: build/release/*
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
|
build-macos:
|
||||||
|
name: Build macOS Assets
|
||||||
|
needs: [prepare, test-gate]
|
||||||
|
runs-on: macos-15
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Preflight macOS release toolchain
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
printf 'runner arch: %s\n' "$(uname -m)"
|
||||||
|
command -v go
|
||||||
|
go version
|
||||||
|
command -v bun
|
||||||
|
bun --version
|
||||||
|
command -v ditto
|
||||||
|
command -v hdiutil
|
||||||
|
xcrun --find lipo
|
||||||
|
xcrun --find vtool
|
||||||
|
make release-assets-check
|
||||||
|
|
||||||
|
- name: Build macOS release assets
|
||||||
|
run: make release-assets-macos
|
||||||
|
|
||||||
|
- name: Upload macOS release assets
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: release-macos
|
||||||
|
path: build/release/*
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
|
draft-release:
|
||||||
|
name: Create Draft Release
|
||||||
|
needs: [prepare, build-web, build-linux, build-windows, build-macos]
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Download release assets
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
pattern: release-*
|
||||||
|
merge-multiple: true
|
||||||
|
path: dist
|
||||||
|
|
||||||
|
- name: Generate checksums
|
||||||
|
run: make release-assets-checksums RELEASE_DIR=dist
|
||||||
|
|
||||||
|
- name: Publish draft release
|
||||||
|
uses: softprops/action-gh-release@v2
|
||||||
|
with:
|
||||||
|
name: v${{ needs.prepare.outputs.version }}
|
||||||
|
tag_name: ${{ github.ref_name }}
|
||||||
|
draft: true
|
||||||
|
files: |
|
||||||
|
dist/*
|
||||||
116
.github/workflows/test.yml
vendored
Normal file
116
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
name: Test (Full)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
full:
|
||||||
|
description: "Run full test suite including MySQL and E2E"
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
type: boolean
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check:
|
||||||
|
name: Check (${{ matrix.os }})
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Install Linux system dependencies
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: sudo apt-get update && sudo apt-get install -y libayatana-appindicator3-dev
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: make lint
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: make test
|
||||||
|
|
||||||
|
mysql:
|
||||||
|
name: MySQL Tests
|
||||||
|
if: inputs.full
|
||||||
|
needs: check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
services:
|
||||||
|
mysql:
|
||||||
|
image: mysql:8.0
|
||||||
|
env:
|
||||||
|
MYSQL_ROOT_PASSWORD: testpass
|
||||||
|
MYSQL_DATABASE: nex_test
|
||||||
|
MYSQL_USER: nex_test
|
||||||
|
MYSQL_PASSWORD: testpass
|
||||||
|
ports:
|
||||||
|
- 13306:3306
|
||||||
|
options: >-
|
||||||
|
--health-cmd="mysqladmin ping -h localhost -u root -ptestpass"
|
||||||
|
--health-interval=3s
|
||||||
|
--health-timeout=5s
|
||||||
|
--health-retries=10
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: MySQL tests
|
||||||
|
run: cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||||
|
|
||||||
|
e2e:
|
||||||
|
name: E2E Tests
|
||||||
|
if: inputs.full
|
||||||
|
needs: check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.work
|
||||||
|
cache-dependency-path: |
|
||||||
|
backend/go.sum
|
||||||
|
versionctl/go.sum
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@v2
|
||||||
|
|
||||||
|
- name: Install Playwright browsers
|
||||||
|
run: cd frontend && bunx playwright install --with-deps chromium
|
||||||
|
|
||||||
|
- name: E2E tests
|
||||||
|
run: cd frontend && bun run test:e2e
|
||||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -399,15 +399,21 @@ env/
|
|||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
# Custom
|
# Custom
|
||||||
.claude
|
.claude/*
|
||||||
|
!.claude/settings.json
|
||||||
.opencode
|
.opencode
|
||||||
|
.codex
|
||||||
openspec/changes/archive
|
openspec/changes/archive
|
||||||
temp
|
temp
|
||||||
.agents
|
.agents
|
||||||
skills-lock.json
|
skills-lock.json
|
||||||
.worktrees
|
.worktrees
|
||||||
!scripts/build/
|
!scripts/build/
|
||||||
|
backend/bin
|
||||||
|
backend/server
|
||||||
|
backend/desktop
|
||||||
|
|
||||||
# Embedfs generated
|
# Embedfs generated
|
||||||
embedfs/assets/
|
embedfs/assets/
|
||||||
embedfs/frontend-dist/
|
embedfs/frontend-dist/
|
||||||
|
backend/cmd/desktop/rsrc_windows_*.syso
|
||||||
|
|||||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"files.eol": "\n"
|
||||||
|
}
|
||||||
184
LICENSE
Normal file
184
LICENSE
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction, and
|
||||||
|
distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by the copyright
|
||||||
|
owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all other entities
|
||||||
|
that control, are controlled by, or are under common control with that entity.
|
||||||
|
For the purposes of this definition, "control" means (i) the power, direct or
|
||||||
|
indirect, to cause the direction or management of such entity, whether by
|
||||||
|
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
||||||
|
permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications, including
|
||||||
|
but not limited to software source code, documentation source, and configuration
|
||||||
|
files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical transformation or
|
||||||
|
translation of a Source form, including but not limited to compiled object code,
|
||||||
|
generated documentation, and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or Object form,
|
||||||
|
made available under the License, as indicated by a copyright notice that is
|
||||||
|
included in or attached to the work (an example is provided in the Appendix
|
||||||
|
below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object form, that
|
||||||
|
is based on (or derived from) the Work and for which the editorial revisions,
|
||||||
|
annotations, elaborations, or other modifications represent, as a whole, an
|
||||||
|
original work of authorship. For the purposes of this License, Derivative Works
|
||||||
|
shall not include works that remain separable from, or merely link (or bind by
|
||||||
|
name) to the interfaces of, the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including the original version
|
||||||
|
of the Work and any modifications or additions to that Work or Derivative Works
|
||||||
|
thereof, that is intentionally submitted to Licensor for inclusion in the Work
|
||||||
|
by the copyright owner or by an individual or Legal Entity authorized to submit
|
||||||
|
on behalf of the copyright owner. For the purposes of this definition,
|
||||||
|
"submitted" means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems, and
|
||||||
|
issue tracking systems that are managed by, or on behalf of, the Licensor for
|
||||||
|
the purpose of discussing and improving the Work, but excluding communication
|
||||||
|
that is conspicuously marked or otherwise designated in writing by the copyright
|
||||||
|
owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
|
||||||
|
of whom a Contribution has been received by Licensor and subsequently
|
||||||
|
incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of this
|
||||||
|
License, each Contributor hereby grants to You a perpetual, worldwide,
|
||||||
|
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
|
||||||
|
reproduce, prepare Derivative Works of, publicly display, publicly perform,
|
||||||
|
sublicense, and distribute the Work and such Derivative Works in Source or
|
||||||
|
Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of this License,
|
||||||
|
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||||
|
no-charge, royalty-free, irrevocable (except as stated in this section) patent
|
||||||
|
license to make, have made, use, offer to sell, sell, import, and otherwise
|
||||||
|
transfer the Work, where such license applies only to those patent claims
|
||||||
|
licensable by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s) with the Work
|
||||||
|
to which such Contribution(s) was submitted. If You institute patent litigation
|
||||||
|
against any entity (including a cross-claim or counterclaim in a lawsuit)
|
||||||
|
alleging that the Work or a Contribution incorporated within the Work
|
||||||
|
constitutes direct or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate as of the date
|
||||||
|
such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the Work or
|
||||||
|
Derivative Works thereof in any medium, with or without modifications, and in
|
||||||
|
Source or Object form, provided that You meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or Derivative Works a copy of
|
||||||
|
this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices stating that
|
||||||
|
You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works that You
|
||||||
|
distribute, all copyright, patent, trademark, and attribution notices from the
|
||||||
|
Source form of the Work, excluding those notices that do not pertain to any part
|
||||||
|
of the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
|
||||||
|
any Derivative Works that You distribute must include a readable copy of the
|
||||||
|
attribution notices contained within such NOTICE file, excluding those notices
|
||||||
|
that do not pertain to any part of the Derivative Works, in at least one of the
|
||||||
|
following places: within a NOTICE text file distributed as part of the
|
||||||
|
Derivative Works; within the Source form or documentation, if provided along
|
||||||
|
with the Derivative Works; or, within a display generated by the Derivative
|
||||||
|
Works, if and wherever such third-party notices normally appear. The contents of
|
||||||
|
the NOTICE file are for informational purposes only and do not modify the
|
||||||
|
License. You may add Your own attribution notices within Derivative Works that
|
||||||
|
You distribute, alongside or as an addendum to the NOTICE text from the Work,
|
||||||
|
provided that such additional attribution notices cannot be construed as
|
||||||
|
modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and may provide
|
||||||
|
additional or different license terms and conditions for use, reproduction, or
|
||||||
|
distribution of Your modifications, or for any such Derivative Works as a whole,
|
||||||
|
provided Your use, reproduction, and distribution of the Work otherwise complies
|
||||||
|
with the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise, any
|
||||||
|
Contribution intentionally submitted for inclusion in the Work by You to the
|
||||||
|
Licensor shall be under the terms and conditions of this License, without any
|
||||||
|
additional terms or conditions. Notwithstanding the above, nothing herein shall
|
||||||
|
supersede or modify the terms of any separate license agreement you may have
|
||||||
|
executed with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade names,
|
||||||
|
trademarks, service marks, or product names of the Licensor, except as required
|
||||||
|
for reasonable and customary use in describing the origin of the Work and
|
||||||
|
reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
|
||||||
|
writing, Licensor provides the Work (and each Contributor provides its
|
||||||
|
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
KIND, either express or implied, including, without limitation, any warranties
|
||||||
|
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any risks
|
||||||
|
associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory, whether in
|
||||||
|
tort (including negligence), contract, or otherwise, unless required by
|
||||||
|
applicable law (such as deliberate and grossly negligent acts) or agreed to in
|
||||||
|
writing, shall any Contributor be liable to You for damages, including any
|
||||||
|
direct, indirect, special, incidental, or consequential damages of any character
|
||||||
|
arising as a result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill, work stoppage,
|
||||||
|
computer failure or malfunction, or any and all other commercial damages or
|
||||||
|
losses), even if such Contributor has been advised of the possibility of such
|
||||||
|
damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing the Work or
|
||||||
|
Derivative Works thereof, You may choose to offer, and charge a fee for,
|
||||||
|
acceptance of support, warranty, indemnity, or other liability obligations
|
||||||
|
and/or rights consistent with this License. However, in accepting such
|
||||||
|
obligations, You may act only on Your own behalf and on Your sole
|
||||||
|
responsibility, not on behalf of any other Contributor, and only if You agree to
|
||||||
|
indemnify, defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason of your
|
||||||
|
accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following boilerplate
|
||||||
|
notice, with the fields enclosed by brackets "[]" replaced with your own
|
||||||
|
identifying information. (Don't include the brackets!) The text should be
|
||||||
|
enclosed in the appropriate comment syntax for the file format. We also
|
||||||
|
recommend that a file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier identification within
|
||||||
|
third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
580
Makefile
580
Makefile
@@ -1,112 +1,520 @@
|
|||||||
.PHONY: all clean \
|
.PHONY: \
|
||||||
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \
|
lint test clean hooks-install hooks-check hooks-test \
|
||||||
backend-lint backend-deps backend-generate \
|
version-sync version-check version-bump \
|
||||||
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \
|
server-run server-build server-lint server-test server-clean \
|
||||||
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \
|
desktop-build-mac desktop-build-win desktop-build-linux \
|
||||||
desktop-mac desktop-win desktop-linux package-macos
|
desktop-lint desktop-test desktop-clean \
|
||||||
|
release-assets-check release-assets-web release-assets-linux release-assets-windows release-assets-macos release-assets-checksums \
|
||||||
|
release-assets-server-linux release-assets-server-windows release-assets-server-macos \
|
||||||
|
release-assets-desktop-linux release-assets-desktop-windows release-assets-desktop-macos \
|
||||||
|
_backend-lint _backend-test _backend-clean _backend-build \
|
||||||
|
_versionctl-lint _versionctl-test \
|
||||||
|
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
|
||||||
|
_hooks-pre-commit _check-clean-worktree \
|
||||||
|
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
|
||||||
|
_server-run-backend _server-run-frontend \
|
||||||
|
_check-linux-target-arch _check-windows-target-arch _ensure-appimagetool \
|
||||||
|
_package-linux-tar _package-linux-appimage _package-linux-deb _package-linux-rpm \
|
||||||
|
_package-macos-zip _package-macos-dmg
|
||||||
|
|
||||||
|
# Delay shell lookups until a target needs them, then cache the result for this make run.
|
||||||
|
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
|
||||||
|
|
||||||
|
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
|
||||||
|
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||||
|
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
TARGET_ARCH ?= $(call lazy_shell,_TARGET_ARCH,go env GOARCH)
|
||||||
|
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||||
|
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
|
||||||
|
RELEASE_DIR ?= build/release
|
||||||
|
LINUX_DESKTOP_BINARY = build/nex-linux-$(TARGET_ARCH)
|
||||||
|
WINDOWS_DESKTOP_BINARY = build/nex-win-$(TARGET_ARCH).exe
|
||||||
|
WINDOWS_SERVER_BINARY = build/nex-server-windows-$(TARGET_ARCH).exe
|
||||||
|
WINDRES ?= windres
|
||||||
|
|
||||||
|
ifeq ($(TARGET_ARCH),arm64)
|
||||||
|
APPIMAGE_ARCH := aarch64
|
||||||
|
DEB_ARCH := arm64
|
||||||
|
RPM_ARCH := aarch64
|
||||||
|
else
|
||||||
|
APPIMAGE_ARCH := x86_64
|
||||||
|
DEB_ARCH := amd64
|
||||||
|
RPM_ARCH := x86_64
|
||||||
|
endif
|
||||||
|
|
||||||
|
WINDOWS_WINDRES_FORMAT_BFD := pe-x86-64
|
||||||
|
WINDOWS_WINDRES_FORMAT_LLVM := x86_64-w64-mingw32
|
||||||
|
WINDOWS_RESOURCE := rsrc_windows_amd64.syso
|
||||||
|
|
||||||
|
APPIMAGETOOL_PATH := build/tools/appimagetool-$(APPIMAGE_ARCH).AppImage
|
||||||
|
APPIMAGETOOL_URL ?= https://github.com/AppImage/AppImageKit/releases/download/continuous/appimagetool-$(APPIMAGE_ARCH).AppImage
|
||||||
|
APPIMAGETOOL ?= $(APPIMAGETOOL_PATH)
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 后端
|
# 全局命令
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
all: backend-build
|
lint: _backend-lint _frontend-check _versionctl-lint
|
||||||
|
@printf 'Lint complete\n'
|
||||||
|
|
||||||
backend-build:
|
test: _backend-test _frontend-test _desktop-test _versionctl-test
|
||||||
cd backend && go build -o bin/server ./cmd/server
|
@printf 'All tests passed\n'
|
||||||
|
|
||||||
backend-run:
|
clean: _backend-clean _frontend-clean _desktop-clean
|
||||||
cd backend && go run ./cmd/server
|
@printf 'Clean complete\n'
|
||||||
|
|
||||||
backend-test:
|
|
||||||
cd backend && go test ./... -v
|
|
||||||
|
|
||||||
backend-test-unit:
|
|
||||||
cd backend && go test ./internal/... ./pkg/... -v
|
|
||||||
|
|
||||||
backend-test-integration:
|
|
||||||
cd backend && go test ./tests/... -v
|
|
||||||
|
|
||||||
backend-test-coverage:
|
|
||||||
cd backend && go test ./... -coverprofile=coverage.out
|
|
||||||
cd backend && go tool cover -html=coverage.out -o coverage.html
|
|
||||||
@echo "Coverage report generated: backend/coverage.html"
|
|
||||||
|
|
||||||
backend-lint:
|
|
||||||
cd backend && go tool golangci-lint run ./...
|
|
||||||
|
|
||||||
backend-deps:
|
|
||||||
cd backend && go mod tidy
|
|
||||||
|
|
||||||
backend-generate:
|
|
||||||
cd backend && go generate ./...
|
|
||||||
|
|
||||||
backend-migrate-up:
|
|
||||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) up
|
|
||||||
|
|
||||||
backend-migrate-down:
|
|
||||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) down
|
|
||||||
|
|
||||||
backend-migrate-status:
|
|
||||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) status
|
|
||||||
|
|
||||||
backend-migrate-create:
|
|
||||||
@read -p "Migration name: " name; \
|
|
||||||
cd backend && goose -dir migrations create $$name sql
|
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 前端
|
# Git hooks
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
frontend-build:
|
hooks-install:
|
||||||
cd frontend && bun install && bun run build
|
@hooks_dir=$$(git rev-parse --git-path hooks); \
|
||||||
|
mkdir -p "$$hooks_dir"; \
|
||||||
|
for hook in pre-commit commit-msg prepare-commit-msg; do \
|
||||||
|
src="scripts/git-hooks/$$hook"; \
|
||||||
|
if [ ! -f "$$src" ]; then \
|
||||||
|
printf 'ERROR: source hook not found: %s\n' "$$src" >&2; \
|
||||||
|
exit 1; \
|
||||||
|
fi; \
|
||||||
|
cp "$$src" "$$hooks_dir/$$hook"; \
|
||||||
|
chmod +x "$$hooks_dir/$$hook"; \
|
||||||
|
done; \
|
||||||
|
printf 'Installed Git hooks to %s\n' "$$hooks_dir"
|
||||||
|
|
||||||
frontend-dev:
|
hooks-check:
|
||||||
cd frontend && bun dev
|
@hooks_dir=$$(git rev-parse --git-path hooks); \
|
||||||
|
status=0; \
|
||||||
|
for hook in pre-commit commit-msg prepare-commit-msg; do \
|
||||||
|
if [ -x "$$hooks_dir/$$hook" ]; then \
|
||||||
|
printf 'OK: %s\n' "$$hook"; \
|
||||||
|
else \
|
||||||
|
printf 'MISSING: %s (%s/%s)\n' "$$hook" "$$hooks_dir" "$$hook"; \
|
||||||
|
status=1; \
|
||||||
|
fi; \
|
||||||
|
done; \
|
||||||
|
exit $$status
|
||||||
|
|
||||||
frontend-test:
|
hooks-test:
|
||||||
cd frontend && bun run test
|
@scripts/git-hooks/test-hooks.sh
|
||||||
|
|
||||||
frontend-test-watch:
|
_hooks-pre-commit:
|
||||||
cd frontend && bun run test:watch
|
@set -ef; \
|
||||||
|
staged_files=$$(git diff --cached --name-only --diff-filter=ACM); \
|
||||||
frontend-test-coverage:
|
if [ -z "$$staged_files" ]; then \
|
||||||
cd frontend && bun run test:coverage
|
printf 'No staged files to check\n'; \
|
||||||
|
exit 0; \
|
||||||
frontend-test-e2e:
|
fi; \
|
||||||
cd frontend && bun run test:e2e
|
run_backend_lint=; \
|
||||||
|
run_versionctl_lint=; \
|
||||||
frontend-lint:
|
run_frontend_check=; \
|
||||||
cd frontend && bun run lint
|
lfs_patterns=$$(grep 'filter=lfs' .gitattributes 2>/dev/null | awk '{print $$1}' || true); \
|
||||||
|
for file in $$staged_files; do \
|
||||||
|
[ -n "$$file" ] || continue; \
|
||||||
|
if git show ":$$file" 2>/dev/null | grep -Eq '^(<<<<<<<|=======|>>>>>>>)'; then \
|
||||||
|
printf 'Found conflict markers in staged file: %s\n' "$$file" >&2; \
|
||||||
|
printf 'Resolve conflict markers before committing.\n' >&2; \
|
||||||
|
exit 1; \
|
||||||
|
fi; \
|
||||||
|
size=$$(git cat-file -s ":$$file" 2>/dev/null || printf '0'); \
|
||||||
|
if [ "$$size" -gt 512000 ] 2>/dev/null; then \
|
||||||
|
if git show ":$$file" 2>/dev/null | LC_ALL=C grep -Iq .; then \
|
||||||
|
printf 'Warning: large staged text file (%s bytes): %s\n' "$$size" "$$file" >&2; \
|
||||||
|
fi; \
|
||||||
|
fi; \
|
||||||
|
if [ -n "$$lfs_patterns" ]; then \
|
||||||
|
for lfs_pat in $$lfs_patterns; do \
|
||||||
|
case "$$file" in $$lfs_pat) \
|
||||||
|
content=$$(git show ":$$file" 2>/dev/null | head -1); \
|
||||||
|
case "$$content" in \
|
||||||
|
"version https://git-lfs.github.com/spec/v1"*) ;; \
|
||||||
|
*) \
|
||||||
|
printf 'LFS-tracked file not using LFS pointer: %s\n' "$$file" >&2; \
|
||||||
|
printf 'Run "git lfs install" and re-add this file.\n' >&2; \
|
||||||
|
exit 1; \
|
||||||
|
;; \
|
||||||
|
esac; \
|
||||||
|
break; \
|
||||||
|
;; \
|
||||||
|
esac; \
|
||||||
|
done; \
|
||||||
|
fi; \
|
||||||
|
case "$$file" in \
|
||||||
|
backend/*.go) run_backend_lint=1 ;; \
|
||||||
|
versionctl/*.go) run_versionctl_lint=1 ;; \
|
||||||
|
frontend/*.ts|frontend/*.tsx|frontend/*.scss) run_frontend_check=1 ;; \
|
||||||
|
esac; \
|
||||||
|
done; \
|
||||||
|
if [ -n "$$run_backend_lint" ]; then \
|
||||||
|
printf 'Running backend lint...\n'; \
|
||||||
|
$(MAKE) _backend-lint; \
|
||||||
|
fi; \
|
||||||
|
if [ -n "$$run_versionctl_lint" ]; then \
|
||||||
|
printf 'Running versionctl lint...\n'; \
|
||||||
|
$(MAKE) _versionctl-lint; \
|
||||||
|
fi; \
|
||||||
|
if [ -n "$$run_frontend_check" ]; then \
|
||||||
|
printf 'Running frontend check...\n'; \
|
||||||
|
$(MAKE) _frontend-check; \
|
||||||
|
fi; \
|
||||||
|
printf 'Pre-commit checks passed\n'
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 桌面应用
|
# 版本管理
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
frontend-build-desktop:
|
version-sync:
|
||||||
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local
|
go run ./versionctl sync
|
||||||
|
|
||||||
embedfs-prepare:
|
version-check:
|
||||||
|
go run ./versionctl check
|
||||||
|
|
||||||
|
version-bump: BUMP ?= patch
|
||||||
|
version-bump: lint test _check-clean-worktree
|
||||||
|
@set -e; \
|
||||||
|
bump_arg="$(if $(SET_VERSION),$(SET_VERSION),$(BUMP))"; \
|
||||||
|
new_version=$$(go run ./versionctl bump "$$bump_arg"); \
|
||||||
|
git add VERSION frontend/; \
|
||||||
|
git commit -m "chore: 版本升迁 v$$new_version"; \
|
||||||
|
git tag "v$$new_version"; \
|
||||||
|
printf '版本升迁完成: v%s\n' "$$new_version"
|
||||||
|
|
||||||
|
_check-clean-worktree:
|
||||||
|
@if [ -n "$$(git status --porcelain)" ]; then \
|
||||||
|
printf '工作区不干净,请先提交或清理改动后再执行版本升迁。\n' >&2; \
|
||||||
|
git status --short; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# Server 模式
|
||||||
|
# ============================================
|
||||||
|
|
||||||
|
server-run:
|
||||||
|
@$(MAKE) -j2 _server-run-backend _server-run-frontend
|
||||||
|
|
||||||
|
server-build: version-check _backend-build _frontend-build
|
||||||
|
@printf 'Server build complete\n'
|
||||||
|
|
||||||
|
server-lint: _backend-lint _frontend-check
|
||||||
|
@printf 'Server lint complete\n'
|
||||||
|
|
||||||
|
server-test: _backend-test _frontend-test
|
||||||
|
@printf 'Server tests passed\n'
|
||||||
|
|
||||||
|
server-clean: _backend-clean _frontend-clean
|
||||||
|
@printf 'Server artifacts cleaned\n'
|
||||||
|
|
||||||
|
_server-run-backend:
|
||||||
|
@$(MAKE) -C backend run
|
||||||
|
|
||||||
|
_server-run-frontend: _frontend-install
|
||||||
|
cd frontend && bun run dev
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# Desktop 模式
|
||||||
|
# ============================================
|
||||||
|
|
||||||
|
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||||
|
@printf 'Building macOS desktop...\n'
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
|
||||||
|
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
|
||||||
|
lipo -info build/nex-mac-universal | grep -q 'x86_64 arm64'
|
||||||
|
rm -f build/nex-mac-arm64 build/nex-mac-amd64
|
||||||
|
@printf 'Packaging macOS app bundle...\n'
|
||||||
|
rm -rf build/Nex.app
|
||||||
|
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
|
||||||
|
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
|
||||||
|
@if [ -f assets/icon.icns ]; then \
|
||||||
|
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
|
||||||
|
else \
|
||||||
|
printf 'Missing assets/icon.icns\n'; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
|
||||||
|
if [ -z "$$MIN_MACOS_VERSION" ]; then \
|
||||||
|
printf 'Unable to read macOS minimum version\n'; \
|
||||||
|
exit 1; \
|
||||||
|
fi; \
|
||||||
|
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
|
||||||
|
chmod +x build/Nex.app/Contents/MacOS/nex
|
||||||
|
@printf 'macOS desktop build complete\n'
|
||||||
|
|
||||||
|
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource _check-windows-target-arch
|
||||||
|
@printf 'Building Windows desktop $(TARGET_ARCH)...\n'
|
||||||
|
mkdir -p build
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../$(WINDOWS_DESKTOP_BINARY) ./cmd/desktop
|
||||||
|
@printf 'Windows desktop build complete\n'
|
||||||
|
|
||||||
|
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _check-linux-target-arch
|
||||||
|
@printf 'Building Linux desktop $(TARGET_ARCH)...\n'
|
||||||
|
mkdir -p build
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../$(LINUX_DESKTOP_BINARY) ./cmd/desktop
|
||||||
|
@printf 'Linux desktop build complete\n'
|
||||||
|
|
||||||
|
desktop-lint: _backend-lint _frontend-check
|
||||||
|
@printf 'Desktop lint complete\n'
|
||||||
|
|
||||||
|
desktop-test: _desktop-test
|
||||||
|
@printf 'Desktop tests passed\n'
|
||||||
|
|
||||||
|
desktop-clean: _desktop-clean
|
||||||
|
@printf 'Desktop artifacts cleaned\n'
|
||||||
|
|
||||||
|
_desktop-test:
|
||||||
|
cd backend && go test ./cmd/desktop/... -v
|
||||||
|
|
||||||
|
_desktop-clean:
|
||||||
|
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
|
||||||
|
|
||||||
|
_desktop-prepare-frontend: _frontend-install
|
||||||
|
@printf 'Preparing frontend for desktop...\n'
|
||||||
|
cd frontend && cp .env.desktop .env.production.local
|
||||||
|
cd frontend && bun run build
|
||||||
|
rm -f frontend/.env.production.local
|
||||||
|
|
||||||
|
_desktop-prepare-embedfs:
|
||||||
|
@printf 'Preparing embedded filesystem...\n'
|
||||||
rm -rf embedfs/assets embedfs/frontend-dist
|
rm -rf embedfs/assets embedfs/frontend-dist
|
||||||
cp -r assets embedfs/assets
|
cp -r assets embedfs/assets
|
||||||
cp -r frontend/dist embedfs/frontend-dist
|
cp -r frontend/dist embedfs/frontend-dist
|
||||||
|
|
||||||
desktop-mac: frontend-build-desktop embedfs-prepare
|
_desktop-prepare-windows-resource: _check-windows-target-arch
|
||||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-mac-arm64 ./cmd/desktop
|
@printf 'Preparing Windows $(TARGET_ARCH) executable icon...\n'
|
||||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-mac-amd64 ./cmd/desktop
|
@WINDRES_CMD="$(WINDRES)"; \
|
||||||
|
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_BFD)"; \
|
||||||
desktop-win: frontend-build-desktop embedfs-prepare
|
if command -v llvm-windres >/dev/null 2>&1; then \
|
||||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
WINDRES_CMD=llvm-windres; \
|
||||||
|
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_LLVM)"; \
|
||||||
desktop-linux: frontend-build-desktop embedfs-prepare
|
elif "$$WINDRES_CMD" --version 2>&1 | grep -qi LLVM; then \
|
||||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
|
WINDRES_FMT="$(WINDOWS_WINDRES_FORMAT_LLVM)"; \
|
||||||
|
fi; \
|
||||||
package-macos:
|
command -v "$$WINDRES_CMD" >/dev/null 2>&1 || { printf 'Missing windres tool: %s\n' "$$WINDRES_CMD"; exit 1; }; \
|
||||||
./scripts/build/package-macos.sh
|
cd backend/cmd/desktop && "$$WINDRES_CMD" -O coff -F "$$WINDRES_FMT" -i icon_windows.rc -o $(WINDOWS_RESOURCE)
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 清理
|
# 发布资产
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
clean:
|
release-assets-check:
|
||||||
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
|
go run ./versionctl release-assets-check
|
||||||
rm -rf build/
|
@printf 'Release assets check passed\n'
|
||||||
|
|
||||||
|
release-assets-web: version-check release-assets-check _frontend-build
|
||||||
|
rm -rf "$(RELEASE_DIR)"
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
asset=$$(go run ./versionctl asset-name web tar.gz); \
|
||||||
|
tar -C frontend -czf "$(RELEASE_DIR)/$$asset" dist
|
||||||
|
|
||||||
|
release-assets-linux: version-check release-assets-check _check-linux-target-arch
|
||||||
|
rm -rf "$(RELEASE_DIR)"
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-server-linux TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-desktop-linux TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
|
||||||
|
release-assets-windows: version-check release-assets-check _check-windows-target-arch
|
||||||
|
rm -rf "$(RELEASE_DIR)"
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-server-windows TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-desktop-windows TARGET_ARCH=$(TARGET_ARCH) RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
|
||||||
|
release-assets-macos: version-check release-assets-check
|
||||||
|
rm -rf "$(RELEASE_DIR)"
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-server-macos RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
@$(MAKE) release-assets-desktop-macos RELEASE_DIR="$(RELEASE_DIR)"
|
||||||
|
|
||||||
|
release-assets-server-linux: version-check _check-linux-target-arch
|
||||||
|
mkdir -p build "$(RELEASE_DIR)"
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-$(TARGET_ARCH) ./cmd/server
|
||||||
|
asset=$$(go run ./versionctl asset-name server linux $(TARGET_ARCH) tar.gz); \
|
||||||
|
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-linux-$(TARGET_ARCH)
|
||||||
|
|
||||||
|
release-assets-server-windows: version-check _check-windows-target-arch
|
||||||
|
mkdir -p build "$(RELEASE_DIR)"
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=$(TARGET_ARCH) go build -ldflags "$(GO_LDFLAGS)" -o ../$(WINDOWS_SERVER_BINARY) ./cmd/server
|
||||||
|
asset=$$(go run ./versionctl asset-name server windows $(TARGET_ARCH) zip); \
|
||||||
|
if command -v powershell.exe >/dev/null 2>&1; then POWERSHELL=powershell.exe; else POWERSHELL=powershell; fi; \
|
||||||
|
"$$POWERSHELL" -NoProfile -Command "Compress-Archive -LiteralPath '$(WINDOWS_SERVER_BINARY)' -DestinationPath '$(RELEASE_DIR)/$$asset' -Force"
|
||||||
|
|
||||||
|
release-assets-server-macos: version-check
|
||||||
|
mkdir -p build "$(RELEASE_DIR)"
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-macos-amd64 ./cmd/server
|
||||||
|
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-macos-arm64 ./cmd/server
|
||||||
|
lipo -create build/nex-server-macos-amd64 build/nex-server-macos-arm64 -output build/nex-server-macos-universal
|
||||||
|
lipo -info build/nex-server-macos-universal | grep -q 'x86_64 arm64'
|
||||||
|
asset=$$(go run ./versionctl asset-name server macos amd64 tar.gz); \
|
||||||
|
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-amd64
|
||||||
|
asset=$$(go run ./versionctl asset-name server macos arm64 tar.gz); \
|
||||||
|
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-arm64
|
||||||
|
asset=$$(go run ./versionctl asset-name server macos universal tar.gz); \
|
||||||
|
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-server-macos-universal
|
||||||
|
rm -f build/nex-server-macos-amd64 build/nex-server-macos-arm64 build/nex-server-macos-universal
|
||||||
|
|
||||||
|
release-assets-desktop-linux: version-check release-assets-check desktop-build-linux _package-linux-tar _package-linux-appimage _package-linux-deb _package-linux-rpm
|
||||||
|
|
||||||
|
release-assets-desktop-windows: version-check release-assets-check desktop-build-win
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop windows $(TARGET_ARCH) zip); \
|
||||||
|
if command -v powershell.exe >/dev/null 2>&1; then POWERSHELL=powershell.exe; else POWERSHELL=powershell; fi; \
|
||||||
|
"$$POWERSHELL" -NoProfile -Command "Compress-Archive -LiteralPath '$(WINDOWS_DESKTOP_BINARY)' -DestinationPath '$(RELEASE_DIR)/$$asset' -Force"
|
||||||
|
|
||||||
|
release-assets-desktop-macos: version-check release-assets-check desktop-build-mac _package-macos-zip _package-macos-dmg
|
||||||
|
rm -rf build/Nex.app build/dmg
|
||||||
|
|
||||||
|
release-assets-checksums:
|
||||||
|
@cd "$(RELEASE_DIR)" && \
|
||||||
|
rm -f SHA256SUMS && \
|
||||||
|
for asset in *; do \
|
||||||
|
[ -f "$$asset" ] || continue; \
|
||||||
|
if command -v sha256sum >/dev/null 2>&1; then \
|
||||||
|
sha256sum "$$asset"; \
|
||||||
|
elif command -v shasum >/dev/null 2>&1; then \
|
||||||
|
shasum -a 256 "$$asset"; \
|
||||||
|
else \
|
||||||
|
printf 'Missing sha256sum or shasum\n' >&2; \
|
||||||
|
exit 1; \
|
||||||
|
fi; \
|
||||||
|
done > SHA256SUMS && \
|
||||||
|
test -s SHA256SUMS
|
||||||
|
|
||||||
|
_check-linux-target-arch:
|
||||||
|
@if [ "$(TARGET_ARCH)" != "amd64" ] && [ "$(TARGET_ARCH)" != "arm64" ]; then \
|
||||||
|
printf 'Unsupported Linux TARGET_ARCH: %s\n' "$(TARGET_ARCH)"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
_check-windows-target-arch:
|
||||||
|
@if [ "$(TARGET_ARCH)" != "amd64" ]; then \
|
||||||
|
printf 'Unsupported Windows TARGET_ARCH: %s\n' "$(TARGET_ARCH)"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
_ensure-appimagetool:
|
||||||
|
@mkdir -p build/tools
|
||||||
|
@if [ ! -x "$(APPIMAGETOOL)" ]; then \
|
||||||
|
printf 'Downloading appimagetool for %s...\n' "$(APPIMAGE_ARCH)"; \
|
||||||
|
command -v curl >/dev/null 2>&1 || { printf 'Missing curl for appimagetool download\n'; exit 1; }; \
|
||||||
|
curl -L "$(APPIMAGETOOL_URL)" -o "$(APPIMAGETOOL)"; \
|
||||||
|
chmod +x "$(APPIMAGETOOL)"; \
|
||||||
|
fi; \
|
||||||
|
printf 'Using appimagetool: %s\n' "$(APPIMAGETOOL)"
|
||||||
|
|
||||||
|
_package-linux-tar:
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) tar.gz); \
|
||||||
|
tar -C build -czf "$(RELEASE_DIR)/$$asset" nex-linux-$(TARGET_ARCH)
|
||||||
|
|
||||||
|
_package-linux-appimage: _ensure-appimagetool
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
appdir="build/appimage/nex-$(TARGET_ARCH).AppDir"; \
|
||||||
|
rm -rf "$$appdir"; \
|
||||||
|
mkdir -p "$$appdir/usr/bin" "$$appdir/usr/share/applications" "$$appdir/usr/share/icons"; \
|
||||||
|
install -m 0755 "$(LINUX_DESKTOP_BINARY)" "$$appdir/usr/bin/nex"; \
|
||||||
|
install -m 0644 packaging/linux/nex.desktop "$$appdir/nex.desktop"; \
|
||||||
|
install -m 0644 packaging/linux/nex.desktop "$$appdir/usr/share/applications/nex.desktop"; \
|
||||||
|
install -m 0755 packaging/linux/AppRun "$$appdir/AppRun"; \
|
||||||
|
cp -R assets/icons/hicolor "$$appdir/usr/share/icons/"; \
|
||||||
|
cp assets/icon.png "$$appdir/nex.png"; \
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) AppImage); \
|
||||||
|
ARCH=$(APPIMAGE_ARCH) APPIMAGE_EXTRACT_AND_RUN=1 "$(APPIMAGETOOL)" "$$appdir" "$(RELEASE_DIR)/$$asset"; \
|
||||||
|
chmod +x "$(RELEASE_DIR)/$$asset"; \
|
||||||
|
test -s "$(RELEASE_DIR)/$$asset"
|
||||||
|
|
||||||
|
_package-linux-deb:
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
pkgdir="build/pkg/deb/nex-$(TARGET_ARCH)"; \
|
||||||
|
rm -rf "$$pkgdir"; \
|
||||||
|
mkdir -p "$$pkgdir/DEBIAN" "$$pkgdir/usr/bin" "$$pkgdir/usr/share/applications" "$$pkgdir/usr/share/icons"; \
|
||||||
|
install -m 0755 "$(LINUX_DESKTOP_BINARY)" "$$pkgdir/usr/bin/nex"; \
|
||||||
|
install -m 0644 packaging/linux/nex.desktop "$$pkgdir/usr/share/applications/nex.desktop"; \
|
||||||
|
cp -R assets/icons/hicolor "$$pkgdir/usr/share/icons/"; \
|
||||||
|
printf '%s\n' \
|
||||||
|
'Package: nex' \
|
||||||
|
'Version: $(VERSION)' \
|
||||||
|
'Section: utils' \
|
||||||
|
'Priority: optional' \
|
||||||
|
'Architecture: $(DEB_ARCH)' \
|
||||||
|
'Maintainer: Nex Maintainers <noreply@example.com>' \
|
||||||
|
'Depends: libgtk-3-0, libayatana-appindicator3-1, xdg-utils' \
|
||||||
|
'Description: AI Gateway desktop application' \
|
||||||
|
' Nex is an AI Gateway desktop application.' \
|
||||||
|
> "$$pkgdir/DEBIAN/control"; \
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) deb); \
|
||||||
|
dpkg-deb --build --root-owner-group "$$pkgdir" "$(RELEASE_DIR)/$$asset"; \
|
||||||
|
dpkg-deb -I "$(RELEASE_DIR)/$$asset" >/dev/null
|
||||||
|
|
||||||
|
_package-linux-rpm:
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
topdir="$(abspath build/rpmbuild-$(TARGET_ARCH))"; \
|
||||||
|
rm -rf "$$topdir"; \
|
||||||
|
mkdir -p "$$topdir/BUILD" "$$topdir/BUILDROOT" "$$topdir/RPMS" "$$topdir/SOURCES" "$$topdir/SPECS" "$$topdir/SRPMS"; \
|
||||||
|
rpmbuild -bb --target "$(RPM_ARCH)" \
|
||||||
|
--define "_topdir $$topdir" \
|
||||||
|
--define "nex_version $(VERSION)" \
|
||||||
|
--define "nex_binary $(abspath $(LINUX_DESKTOP_BINARY))" \
|
||||||
|
--define "nex_desktop_file $(abspath packaging/linux/nex.desktop)" \
|
||||||
|
--define "nex_icons_dir $(abspath assets/icons/hicolor)" \
|
||||||
|
packaging/linux/nex.spec; \
|
||||||
|
rpm_file=$$(find "$$topdir/RPMS" -type f -name '*.rpm' | sort | tail -n 1); \
|
||||||
|
test -n "$$rpm_file"; \
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop linux $(TARGET_ARCH) rpm); \
|
||||||
|
cp "$$rpm_file" "$(RELEASE_DIR)/$$asset"; \
|
||||||
|
rpm -qip "$(RELEASE_DIR)/$$asset" >/dev/null
|
||||||
|
|
||||||
|
_package-macos-zip:
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop macos universal zip); \
|
||||||
|
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$$asset"
|
||||||
|
|
||||||
|
_package-macos-dmg:
|
||||||
|
mkdir -p "$(RELEASE_DIR)"
|
||||||
|
dmgdir="build/dmg/Nex"; \
|
||||||
|
rm -rf "$$dmgdir"; \
|
||||||
|
mkdir -p "$$dmgdir"; \
|
||||||
|
cp -R build/Nex.app "$$dmgdir/Nex.app"; \
|
||||||
|
ln -s /Applications "$$dmgdir/Applications"; \
|
||||||
|
asset=$$(go run ./versionctl asset-name desktop macos universal dmg); \
|
||||||
|
hdiutil create -volname Nex -srcfolder "$$dmgdir" -ov -format UDZO "$(RELEASE_DIR)/$$asset"; \
|
||||||
|
hdiutil verify "$(RELEASE_DIR)/$$asset" >/dev/null && \
|
||||||
|
rm -rf "$$dmgdir"
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 共享 helper targets
|
||||||
|
# ============================================
|
||||||
|
|
||||||
|
_backend-build:
|
||||||
|
@$(MAKE) -C backend build
|
||||||
|
|
||||||
|
_backend-lint:
|
||||||
|
@$(MAKE) -C backend lint
|
||||||
|
|
||||||
|
_backend-test:
|
||||||
|
@$(MAKE) -C backend test
|
||||||
|
|
||||||
|
_backend-clean:
|
||||||
|
@$(MAKE) -C backend clean
|
||||||
|
|
||||||
|
_versionctl-lint:
|
||||||
|
@$(MAKE) -C versionctl lint
|
||||||
|
|
||||||
|
_versionctl-test:
|
||||||
|
@$(MAKE) -C versionctl test
|
||||||
|
|
||||||
|
_frontend-install:
|
||||||
|
cd frontend && bun install
|
||||||
|
|
||||||
|
_frontend-build: _frontend-install
|
||||||
|
cd frontend && bun run build
|
||||||
|
|
||||||
|
_frontend-check: _frontend-install
|
||||||
|
cd frontend && bun run check
|
||||||
|
|
||||||
|
_frontend-test: _frontend-install
|
||||||
|
cd frontend && bun run test
|
||||||
|
|
||||||
|
_frontend-dev: _frontend-install
|
||||||
|
cd frontend && bun run dev
|
||||||
|
|
||||||
|
_frontend-clean:
|
||||||
|
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
|
||||||
|
|||||||
293
README.md
293
README.md
@@ -27,7 +27,7 @@ nex/
|
|||||||
│ │ ├── api/ # API 层(统一请求封装 + 字段转换)
|
│ │ ├── api/ # API 层(统一请求封装 + 字段转换)
|
||||||
│ │ ├── hooks/ # TanStack Query hooks
|
│ │ ├── hooks/ # TanStack Query hooks
|
||||||
│ │ ├── components/ # 通用组件(AppLayout)
|
│ │ ├── components/ # 通用组件(AppLayout)
|
||||||
│ │ ├── pages/ # 页面(Providers, Stats)
|
│ │ ├── pages/ # 页面(Providers, Stats, Settings)
|
||||||
│ │ ├── routes/ # React Router 路由配置
|
│ │ ├── routes/ # React Router 路由配置
|
||||||
│ │ ├── types/ # TypeScript 类型定义
|
│ │ ├── types/ # TypeScript 类型定义
|
||||||
│ │ └── __tests__/ # 单元测试 + 组件测试
|
│ │ └── __tests__/ # 单元测试 + 组件测试
|
||||||
@@ -36,12 +36,10 @@ nex/
|
|||||||
│
|
│
|
||||||
├── assets/ # 应用资源
|
├── assets/ # 应用资源
|
||||||
│ ├── icon.png # 托盘图标
|
│ ├── icon.png # 托盘图标
|
||||||
│ ├── AppIcon.icns # macOS 应用图标
|
│ ├── icon.icns # macOS 应用图标
|
||||||
│ └── icon.ico # Windows 应用图标
|
│ └── icon.ico # Windows 应用图标
|
||||||
│
|
│
|
||||||
├── scripts/ # 构建脚本
|
├── packaging/ # 桌面发布包元数据(Linux desktop entry、RPM spec 等)
|
||||||
│ └── build/
|
|
||||||
│ └── package-macos.sh # macOS .app 打包脚本
|
|
||||||
│
|
│
|
||||||
└── README.md # 本文件
|
└── README.md # 本文件
|
||||||
```
|
```
|
||||||
@@ -51,7 +49,7 @@ nex/
|
|||||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||||
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
|
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
|
||||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||||
- **Function Calling**:支持工具调用(Tools)
|
- **Function Calling**:支持工具调用(Tools)
|
||||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||||
@@ -59,6 +57,7 @@ nex/
|
|||||||
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||||
- **Web 配置界面**:提供供应商和模型配置管理
|
- **Web 配置界面**:提供供应商和模型配置管理
|
||||||
|
- **启动参数设置**:通过 Web 界面查看和编辑启动参数(Desktop 可编辑、Server 只读)
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
@@ -66,12 +65,26 @@ nex/
|
|||||||
- **语言**: Go 1.26+
|
- **语言**: Go 1.26+
|
||||||
- **HTTP 框架**: Gin
|
- **HTTP 框架**: Gin
|
||||||
- **ORM**: GORM
|
- **ORM**: GORM
|
||||||
- **数据库**: SQLite
|
- **数据库**: SQLite / MySQL
|
||||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转)
|
- **日志**: zap + lumberjack(结构化日志 + 日志轮转 + 模块标识)
|
||||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
- **配置**: Viper + pflag(Server 多层配置,Desktop 配置文件快照)
|
||||||
- **验证**: go-playground/validator/v10
|
- **验证**: go-playground/validator/v10
|
||||||
- **迁移**: goose
|
- **迁移**: goose
|
||||||
|
|
||||||
|
#### 日志模块标识规范
|
||||||
|
|
||||||
|
每个模块通过依赖注入获取带模块标识的 logger,日志输出格式为 `[module.name]`:
|
||||||
|
|
||||||
|
```
|
||||||
|
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
|
||||||
|
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
|
||||||
|
```
|
||||||
|
|
||||||
|
模块命名规范:
|
||||||
|
- 单一职责包:`database`、`config`
|
||||||
|
- 多实体包:`handler.proxy`、`service.provider`
|
||||||
|
- 子包:`handler.middleware`
|
||||||
|
|
||||||
### 前端
|
### 前端
|
||||||
- **运行时**: Bun
|
- **运行时**: Bun
|
||||||
- **构建工具**: Vite
|
- **构建工具**: Vite
|
||||||
@@ -81,7 +94,7 @@ nex/
|
|||||||
- **图表库**: Recharts
|
- **图表库**: Recharts
|
||||||
- **路由**: React Router v7
|
- **路由**: React Router v7
|
||||||
- **数据获取**: TanStack Query v5
|
- **数据获取**: TanStack Query v5
|
||||||
- **样式**: SCSS Modules
|
- **样式**: TDesign 组件 props 优先,TDesign tokens 次之,SCSS 作为兜底补充
|
||||||
- **测试**: Vitest + React Testing Library + Playwright
|
- **测试**: Vitest + React Testing Library + Playwright
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
@@ -91,19 +104,21 @@ nex/
|
|||||||
**构建桌面应用**:
|
**构建桌面应用**:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# macOS (arm64 + amd64)
|
# macOS (arm64 + amd64,并打包为 .app)
|
||||||
make desktop-mac
|
make desktop-build-mac
|
||||||
make package-macos # 打包为 .app
|
|
||||||
|
|
||||||
# Windows
|
# Windows
|
||||||
make desktop-win
|
make desktop-build-win
|
||||||
|
|
||||||
# Linux
|
# Linux
|
||||||
make desktop-linux
|
make desktop-build-linux
|
||||||
|
|
||||||
|
# Linux arm64
|
||||||
|
make desktop-build-linux TARGET_ARCH=arm64
|
||||||
```
|
```
|
||||||
|
|
||||||
**使用桌面应用**:
|
**使用桌面应用**:
|
||||||
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64)
|
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64 / nex-linux-arm64)
|
||||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||||
@@ -111,8 +126,10 @@ make desktop-linux
|
|||||||
**注意事项**:
|
**注意事项**:
|
||||||
- 桌面应用需要 CGO 支持
|
- 桌面应用需要 CGO 支持
|
||||||
- macOS: 自带 Xcode Command Line Tools
|
- macOS: 自带 Xcode Command Line Tools
|
||||||
- Linux: 自带 gcc,部分桌面环境需要 `libappindicator3-dev`
|
- Linux 构建: 需要 gcc、pkg-config、GTK3 开发包和 Ayatana AppIndicator 开发包(Ubuntu/Debian: `libgtk-3-dev`、`libayatana-appindicator3-dev`)
|
||||||
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
|
- Linux 运行: 需要 GTK3、Ayatana AppIndicator 和 xdg-utils;启动失败提示会 best-effort 使用 `notify-send`、`kdialog`、`zenity` 或 `xmessage`,这些通知/弹窗工具为软依赖,缺失时会降级到标准错误输出或日志;AppImage 也依赖系统提供 AppImage runtime/FUSE 能力,不承诺完全自包含
|
||||||
|
- Windows: 需要对应架构的 MinGW-w64/MSYS2 工具链,desktop 使用 GUI linker flags 隐藏控制台窗口
|
||||||
|
- macOS DMG: 发布包暂不签名、不 notarize,首次打开可能出现 Gatekeeper 提示
|
||||||
|
|
||||||
**Linux 桌面环境兼容性**:
|
**Linux 桌面环境兼容性**:
|
||||||
- GNOME: 需要 AppIndicator 扩展
|
- GNOME: 需要 AppIndicator 扩展
|
||||||
@@ -120,50 +137,87 @@ make desktop-linux
|
|||||||
- Xfce: 需要 libappindicator
|
- Xfce: 需要 libappindicator
|
||||||
- 其他支持 StatusNotifierItem 规范的环境
|
- 其他支持 StatusNotifierItem 规范的环境
|
||||||
|
|
||||||
### CLI 模式
|
### Server 模式(前后端分离)
|
||||||
|
|
||||||
#### 后端
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd backend
|
make server-run
|
||||||
go mod download
|
|
||||||
go run cmd/server/main.go
|
|
||||||
```
|
```
|
||||||
|
|
||||||
后端服务将在 `http://localhost:9826` 启动。首次启动会自动:
|
`make server-run` 会并行启动:
|
||||||
- 创建配置文件 `~/.nex/config.yaml`
|
- 后端服务:`http://localhost:9826`
|
||||||
|
- 前端开发服务器:`http://localhost:5173`
|
||||||
|
|
||||||
|
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
|
||||||
- 初始化数据库 `~/.nex/config.db`
|
- 初始化数据库 `~/.nex/config.db`
|
||||||
- 运行数据库迁移
|
- 运行数据库迁移
|
||||||
- 创建日志目录 `~/.nex/log/`
|
- 创建日志目录 `~/.nex/log/`
|
||||||
|
|
||||||
### 前端
|
**构建 server 模式产物**:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd frontend
|
make server-build
|
||||||
bun install
|
|
||||||
bun dev
|
|
||||||
```
|
```
|
||||||
|
|
||||||
前端开发服务器将在 `http://localhost:5173` 启动,API 请求通过 Vite proxy 转发到后端。
|
### Release 产物
|
||||||
|
|
||||||
|
发布流程由 Git tag `vX.Y.Z` 触发,GitHub Actions 会先通过全流程测试门禁,再构建并创建 Draft Release,上传 server、web 和 desktop 三类产物,同时生成 `SHA256SUMS`。
|
||||||
|
|
||||||
|
**server 产物**(不内置 Web 管理界面):
|
||||||
|
|
||||||
|
| 平台 | 产物 |
|
||||||
|
|------|------|
|
||||||
|
| Linux amd64 | `nex-server_<version>_linux_amd64.tar.gz` |
|
||||||
|
| Linux arm64 | `nex-server_<version>_linux_arm64.tar.gz` |
|
||||||
|
| macOS amd64 | `nex-server_<version>_macos_amd64.tar.gz` |
|
||||||
|
| macOS arm64 | `nex-server_<version>_macos_arm64.tar.gz` |
|
||||||
|
| macOS universal | `nex-server_<version>_macos_universal.tar.gz` |
|
||||||
|
| Windows amd64 | `nex-server_<version>_windows_amd64.zip` |
|
||||||
|
|
||||||
|
**web 产物**:
|
||||||
|
|
||||||
|
| 内容 | 产物 |
|
||||||
|
|------|------|
|
||||||
|
| `frontend/dist` | `nex-web_<version>.tar.gz` |
|
||||||
|
|
||||||
|
**desktop 产物**:
|
||||||
|
|
||||||
|
| 平台 | 产物 |
|
||||||
|
|------|------|
|
||||||
|
| Linux amd64 | `nex-desktop_<version>_linux_amd64.tar.gz`、`.AppImage`、`.deb`、`.rpm` |
|
||||||
|
| Linux arm64 | `nex-desktop_<version>_linux_arm64.tar.gz`、`.AppImage`、`.deb`、`.rpm` |
|
||||||
|
| macOS universal | `nex-desktop_<version>_macos_universal.zip`、`nex-desktop_<version>_macos_universal.dmg` |
|
||||||
|
| Windows amd64 | `nex-desktop_<version>_windows_amd64.zip` |
|
||||||
|
|
||||||
|
Linux deb 包声明 `libgtk-3-0`、`libayatana-appindicator3-1`、`xdg-utils` 运行依赖;rpm 包声明 `gtk3`、`libayatana-appindicator-gtk3`、`xdg-utils` 运行依赖。Rocky Linux 9 等发行版可能需要启用 EPEL 才能解析 Ayatana AppIndicator 依赖。
|
||||||
|
|
||||||
|
server 和 desktop 发布产物自包含运行时数据库迁移资源(通过 `go:embed` 嵌入二进制),安装后首次启动不再依赖仓库源码目录。
|
||||||
|
|
||||||
## API 接口
|
## API 接口
|
||||||
|
|
||||||
### 代理接口(对外部应用)
|
### 代理接口(对外部应用)
|
||||||
|
|
||||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写并保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
|
||||||
|
|
||||||
**OpenAI 协议**(`protocol=openai`):
|
**OpenAI 协议**(`protocol=openai`):
|
||||||
- `POST /openai/chat/completions` - 对话补全
|
- `POST /openai/v1/chat/completions` - 对话补全
|
||||||
- `GET /openai/models` - 模型列表(本地数据库聚合)
|
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
|
||||||
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||||
- `POST /openai/embeddings` - 嵌入
|
- `POST /openai/v1/embeddings` - 嵌入
|
||||||
- `POST /openai/rerank` - 重排序
|
- `POST /openai/v1/rerank` - 重排序
|
||||||
|
|
||||||
**Anthropic 协议**(`protocol=anthropic`):
|
**Anthropic 协议**(`protocol=anthropic`):
|
||||||
- `POST /anthropic/v1/messages` - 消息对话
|
- `POST /anthropic/v1/messages` - 消息对话
|
||||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||||
|
|
||||||
|
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions`、`/v1/models`、`/v1/embeddings`、`/v1/rerank`,并在构建上游 URL 时去掉 `/v1`;Anthropic adapter 接收 `/v1/messages`、`/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`),Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
|
||||||
|
|
||||||
|
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON`、`MODEL_NOT_FOUND`、`CONVERSION_FAILED`、`UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
|
||||||
|
|
||||||
|
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
|
||||||
|
|
||||||
|
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
|
||||||
|
|
||||||
### 管理接口(对前端)
|
### 管理接口(对前端)
|
||||||
|
|
||||||
#### 供应商管理
|
#### 供应商管理
|
||||||
@@ -186,13 +240,29 @@ bun dev
|
|||||||
|
|
||||||
查询参数支持:`provider_id`、`model_name`、`start_date`、`end_date`、`group_by`
|
查询参数支持:`provider_id`、`model_name`、`start_date`、`end_date`、`group_by`
|
||||||
|
|
||||||
|
#### 启动参数设置
|
||||||
|
- `GET /api/settings/startup` - 查询启动参数设置
|
||||||
|
- `PUT /api/settings/startup` - 保存启动参数设置(仅 Desktop 模式)
|
||||||
|
|
||||||
|
**行为差异**:
|
||||||
|
- **Desktop 模式**:查询返回配置文件编辑视图(`~/.nex/config.yaml` + 默认值),允许保存到配置文件,保存后当前运行服务不受影响,需重启 Desktop 生效
|
||||||
|
- **Server 模式**:查询返回当前运行有效配置,保存请求始终返回 403
|
||||||
|
|
||||||
|
响应包含 `mode`、`editable`、`config_path`、`restart_required` 元数据和完整启动参数配置。Duration 字段使用 Go `time.Duration.String()` 标准字符串格式(如 `30s`、`1m0s`、`1h0m0s`);配置文件中用户可手写任意合法 Go duration 字符串(如 `1h`、`30m`),保存时系统会统一为标准格式。
|
||||||
|
|
||||||
|
#### 版本信息
|
||||||
|
- `GET /api/version` - 获取后端构建版本信息(`version`、`commit`、`build_time`),用于前端 About 页面诊断前后端版本一致性
|
||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
|
|
||||||
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
配置方式取决于启动模式:
|
||||||
|
|
||||||
|
- **Server 模式**(`cmd/server`):支持 CLI 参数 > 环境变量 > 配置文件 > 默认值
|
||||||
|
- **Desktop 模式**(`cmd/desktop`):仅支持配置文件 `~/.nex/config.yaml` > 默认值,修改配置文件后需重启 desktop 生效
|
||||||
|
|
||||||
### 配置文件
|
### 配置文件
|
||||||
|
|
||||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
|
配置文件位于 `~/.nex/config.yaml`。配置文件不存在时使用默认值,不会自动生成;需要自定义时手动创建该文件:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
server:
|
server:
|
||||||
@@ -201,10 +271,17 @@ server:
|
|||||||
write_timeout: 30s
|
write_timeout: 30s
|
||||||
|
|
||||||
database:
|
database:
|
||||||
path: ~/.nex/config.db
|
driver: sqlite # sqlite 或 mysql
|
||||||
|
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||||
|
# --- MySQL 配置(driver=mysql 时生效)---
|
||||||
|
# host: localhost
|
||||||
|
# port: 3306
|
||||||
|
# user: nex
|
||||||
|
# password: ""
|
||||||
|
# dbname: nex
|
||||||
max_idle_conns: 10
|
max_idle_conns: 10
|
||||||
max_open_conns: 100
|
max_open_conns: 100
|
||||||
conn_max_lifetime: 1h
|
conn_max_lifetime: 1h0m0s
|
||||||
|
|
||||||
log:
|
log:
|
||||||
level: info
|
level: info
|
||||||
@@ -215,19 +292,31 @@ log:
|
|||||||
compress: true
|
compress: true
|
||||||
```
|
```
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量(仅 Server 模式)
|
||||||
|
|
||||||
所有配置项支持环境变量,使用 `NEX_` 前缀:
|
Server 模式下,所有配置项支持环境变量,使用 `NEX_` 前缀:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export NEX_SERVER_PORT=9000
|
export NEX_SERVER_PORT=9000
|
||||||
export NEX_DATABASE_PATH=/data/nex.db
|
export NEX_DATABASE_PATH=/data/nex.db
|
||||||
export NEX_LOG_LEVEL=debug
|
export NEX_LOG_LEVEL=debug
|
||||||
|
|
||||||
|
# MySQL 模式
|
||||||
|
export NEX_DATABASE_DRIVER=mysql
|
||||||
|
export NEX_DATABASE_HOST=db.example.com
|
||||||
|
export NEX_DATABASE_PORT=3306
|
||||||
|
export NEX_DATABASE_USER=nex
|
||||||
|
export NEX_DATABASE_PASSWORD=secret
|
||||||
|
export NEX_DATABASE_DBNAME=nex
|
||||||
```
|
```
|
||||||
|
|
||||||
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||||
|
|
||||||
### CLI 参数
|
**Desktop 模式不支持环境变量覆盖。**Desktop 仅从 `~/.nex/config.yaml` 和默认值读取配置。
|
||||||
|
|
||||||
|
### CLI 参数(仅 Server 模式)
|
||||||
|
|
||||||
|
Server 模式下,支持命令行参数:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||||
@@ -235,34 +324,124 @@ export NEX_LOG_LEVEL=debug
|
|||||||
|
|
||||||
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
||||||
|
|
||||||
|
**Desktop 不支持命令行参数覆盖配置。**Desktop 忽略所有 CLI 参数,仅从 `~/.nex/config.yaml` 读取。
|
||||||
|
|
||||||
### 数据文件
|
### 数据文件
|
||||||
|
|
||||||
- `~/.nex/config.yaml` - 配置文件
|
- `~/.nex/config.yaml` - 配置文件
|
||||||
- `~/.nex/config.db` - SQLite 数据库
|
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||||
- `~/.nex/log/` - 日志目录
|
- `~/.nex/log/` - 日志目录
|
||||||
|
|
||||||
## 测试
|
## 测试
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make backend-test # 后端测试
|
# 全局默认测试(不含 MySQL 和前端 E2E)
|
||||||
make backend-test-coverage # 后端覆盖率
|
make test
|
||||||
make frontend-test # 前端测试
|
|
||||||
make frontend-test-e2e # 前端 E2E 测试
|
# 产品级测试
|
||||||
|
make server-test
|
||||||
|
make desktop-test
|
||||||
```
|
```
|
||||||
|
|
||||||
|
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md` 与 `frontend/README.md`。
|
||||||
|
|
||||||
## 开发
|
## 开发
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make backend-build # 构建后端
|
# 首次克隆后安装 Git hooks
|
||||||
make backend-run # 运行后端
|
make hooks-install
|
||||||
make backend-lint # 后端代码检查
|
|
||||||
make backend-migrate-up # 数据库迁移
|
|
||||||
|
|
||||||
make frontend-build # 构建前端
|
# 检查 Git hooks 安装状态
|
||||||
make frontend-dev # 前端开发模式
|
make hooks-check
|
||||||
make frontend-lint # 前端代码检查
|
|
||||||
|
# 运行 Git hooks 回归测试
|
||||||
|
make hooks-test
|
||||||
|
|
||||||
|
# 全局命令
|
||||||
|
make lint # 前后端共享检查
|
||||||
|
make test # 默认全量测试(不含 MySQL/E2E)
|
||||||
|
make clean # 清理所有构建产物和测试报告
|
||||||
|
|
||||||
|
# server 模式
|
||||||
|
make server-run # 并行启动后端和前端开发服务
|
||||||
|
make server-build # 构建 backend/bin/server 和 frontend/dist
|
||||||
|
make server-lint # server 模式检查
|
||||||
|
make server-test # server 模式测试
|
||||||
|
make server-clean # 清理 server 模式产物
|
||||||
|
|
||||||
|
# desktop 模式
|
||||||
|
make desktop-build-mac # 构建 macOS 桌面应用
|
||||||
|
make desktop-build-win # 构建 Windows 桌面应用
|
||||||
|
make desktop-build-linux # 构建 Linux 桌面应用
|
||||||
|
make desktop-lint # desktop 模式检查
|
||||||
|
make desktop-test # desktop 专属测试
|
||||||
|
make desktop-clean # 清理 desktop 产物
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Git hooks 使用仓库内 `scripts/git-hooks/` 的原生脚本,不依赖额外 hook 框架。当前 hooks 包含:
|
||||||
|
|
||||||
|
- pre-commit:检查 staged files 的冲突标记、大文件告警和 LFS 指针,并按文件类型委托后端、versionctl、前端检查
|
||||||
|
- prepare-commit-msg:在编辑器打开时提供提交信息模板,辅助填写 `类型: 简短描述` 和多行说明
|
||||||
|
- commit-msg:校验提交信息格式为 `类型: 简短描述`,多行说明需在首行后保留空行;提交描述按项目规范使用中文,hook 不做字符集检测
|
||||||
|
|
||||||
|
## 版本与发布
|
||||||
|
|
||||||
|
### 统一版本源
|
||||||
|
|
||||||
|
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
|
||||||
|
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
|
||||||
|
|
||||||
|
### 本地版本演进
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 递增版本(自动 lint + test + 工作区检查 + sync/check + commit + tag)
|
||||||
|
make version-bump BUMP=minor
|
||||||
|
|
||||||
|
# 或指定具体版本号
|
||||||
|
make version-bump SET_VERSION=1.0.0
|
||||||
|
|
||||||
|
# 推送到远程
|
||||||
|
git push --follow-tags
|
||||||
|
```
|
||||||
|
|
||||||
|
手动同步和校验:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make version-sync
|
||||||
|
make version-check
|
||||||
|
```
|
||||||
|
|
||||||
|
### 本地生成发布资产
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux: server + desktop
|
||||||
|
make release-assets-linux
|
||||||
|
|
||||||
|
# Windows: server + desktop(需在 Windows 环境执行)
|
||||||
|
make release-assets-windows
|
||||||
|
|
||||||
|
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
|
||||||
|
make release-assets-macos
|
||||||
|
```
|
||||||
|
|
||||||
|
生成的版本化发布资产位于 `build/release/`。
|
||||||
|
|
||||||
|
### GitHub Draft Release
|
||||||
|
|
||||||
|
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
|
||||||
|
- 流水线会先校验 tag 与 `VERSION` 一致,再执行全流程测试门禁(lint、默认测试、MySQL 测试、E2E 测试),测试不通过则阻止构建
|
||||||
|
- 测试通过后,三个平台 job 并行构建,各 job 会在正式构建前先检查 `go`、`bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
|
||||||
|
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
|
||||||
|
- 构建以下资产并上传到 GitHub Draft Release:
|
||||||
|
- Linux server
|
||||||
|
- Windows server
|
||||||
|
- darwin-amd64 server
|
||||||
|
- darwin-arm64 server
|
||||||
|
- Linux desktop
|
||||||
|
- Windows desktop
|
||||||
|
- macOS desktop universal
|
||||||
|
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
|
||||||
|
|
||||||
## 开发规范
|
## 开发规范
|
||||||
|
|
||||||
详见各子项目的 README.md:
|
详见各子项目的 README.md:
|
||||||
@@ -271,4 +450,4 @@ make frontend-lint # 前端代码检查
|
|||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
MIT
|
Apache License 2.0
|
||||||
|
|||||||
Binary file not shown.
@@ -1,64 +0,0 @@
|
|||||||
# Assets
|
|
||||||
|
|
||||||
应用资源文件目录。
|
|
||||||
|
|
||||||
## 文件说明
|
|
||||||
|
|
||||||
| 文件 | 用途 | 尺寸 | 格式 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| `icon.svg` | 源图标 | 64x64 | SVG |
|
|
||||||
| `icon.png` | 托盘图标 | 64x64 | PNG |
|
|
||||||
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
|
|
||||||
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
|
|
||||||
|
|
||||||
## 替换图标
|
|
||||||
|
|
||||||
### 1. 准备图标
|
|
||||||
|
|
||||||
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
|
|
||||||
|
|
||||||
### 2. 生成各平台图标
|
|
||||||
|
|
||||||
**托盘图标 (PNG)**:
|
|
||||||
```bash
|
|
||||||
magick your-icon.svg -resize 64x64 icon.png
|
|
||||||
```
|
|
||||||
|
|
||||||
**macOS 应用图标 (ICNS)**:
|
|
||||||
```bash
|
|
||||||
mkdir icon.iconset
|
|
||||||
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
|
|
||||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
|
|
||||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
|
|
||||||
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
|
|
||||||
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
|
|
||||||
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
|
|
||||||
iconutil -c icns icon.iconset -o AppIcon.icns
|
|
||||||
rm -rf icon.iconset
|
|
||||||
```
|
|
||||||
|
|
||||||
**Windows 应用图标 (ICO)**:
|
|
||||||
```bash
|
|
||||||
magick your-icon.svg -resize 256x256 icon.ico
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 替换文件
|
|
||||||
|
|
||||||
将生成的文件放入此目录,然后重新构建桌面应用:
|
|
||||||
```bash
|
|
||||||
./scripts/build/build-darwin-arm64.sh
|
|
||||||
```
|
|
||||||
|
|
||||||
## macOS Template 图标
|
|
||||||
|
|
||||||
macOS 支持 Template 图标,自动适配深浅色模式:
|
|
||||||
- 使用黑色 + 透明设计
|
|
||||||
- 文件名以 `Template` 结尾(如 `iconTemplate.png`)
|
|
||||||
- 黑色在深色模式下自动变为白色
|
|
||||||
|
|
||||||
## 设计建议
|
|
||||||
|
|
||||||
- 托盘图标应简洁,在小尺寸下清晰可辨
|
|
||||||
- 避免过多细节和文字
|
|
||||||
- 使用高对比度颜色
|
|
||||||
- macOS 建议使用 Template 图标风格
|
|
||||||
BIN
assets/icon.icns
LFS
Normal file
BIN
assets/icon.icns
LFS
Normal file
Binary file not shown.
BIN
assets/icon.ico
BIN
assets/icon.ico
Binary file not shown.
|
Before Width: | Height: | Size: 264 KiB After Width: | Height: | Size: 128 B |
BIN
assets/icon.png
BIN
assets/icon.png
Binary file not shown.
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 131 B |
@@ -1,13 +0,0 @@
|
|||||||
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
|
|
||||||
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
|
|
||||||
<circle cx="32" cy="32" r="6" fill="white"/>
|
|
||||||
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
|
|
||||||
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
|
|
||||||
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
|
|
||||||
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
|
|
||||||
<circle cx="20" cy="20" r="3" fill="white"/>
|
|
||||||
<circle cx="44" cy="20" r="3" fill="white"/>
|
|
||||||
<circle cx="20" cy="44" r="3" fill="white"/>
|
|
||||||
<circle cx="44" cy="44" r="3" fill="white"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 779 B |
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
Binary file not shown.
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
run:
|
||||||
|
timeout: 5m
|
||||||
|
tests: true
|
||||||
|
|
||||||
|
linters:
|
||||||
|
disable-all: true
|
||||||
|
enable:
|
||||||
|
- forbidigo
|
||||||
|
- errorlint
|
||||||
|
- errcheck
|
||||||
|
- staticcheck
|
||||||
|
- revive
|
||||||
|
- gocritic
|
||||||
|
- gosec
|
||||||
|
- bodyclose
|
||||||
|
- noctx
|
||||||
|
- nilerr
|
||||||
|
- goimports
|
||||||
|
- gocyclo
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
errcheck:
|
||||||
|
check-blank: true
|
||||||
|
check-type-assertions: true
|
||||||
|
exclude-functions:
|
||||||
|
- fmt.Fprintf
|
||||||
|
forbidigo:
|
||||||
|
analyze-types: true
|
||||||
|
forbid:
|
||||||
|
- p: '^fmt\.Print.*$'
|
||||||
|
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||||
|
- p: '^fmt\.Fprint.*$'
|
||||||
|
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||||
|
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||||
|
msg: 使用 zap logger,不要使用标准库 log
|
||||||
|
- p: '^zap\.L$'
|
||||||
|
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||||
|
- p: '^zap\.S$'
|
||||||
|
msg: 不使用 Sugar logger
|
||||||
|
revive:
|
||||||
|
rules:
|
||||||
|
- name: exported
|
||||||
|
- name: var-naming
|
||||||
|
- name: indent-error-flow
|
||||||
|
- name: error-strings
|
||||||
|
- name: error-return
|
||||||
|
- name: blank-imports
|
||||||
|
- name: context-as-argument
|
||||||
|
- name: unexported-return
|
||||||
|
goimports:
|
||||||
|
local-prefixes: nex/backend
|
||||||
|
gocyclo:
|
||||||
|
min-complexity: 10
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-dirs:
|
||||||
|
- tests/mocks
|
||||||
|
exclude-generated: true
|
||||||
|
exclude-rules:
|
||||||
|
- path: '(_test\.go|tests/)'
|
||||||
|
linters:
|
||||||
|
- forbidigo
|
||||||
|
- path: '(_test\.go|tests/)'
|
||||||
|
linters:
|
||||||
|
- errcheck
|
||||||
|
source: '(^\s*_\s*=|,\s*_)'
|
||||||
|
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||||
|
linters:
|
||||||
|
- errcheck
|
||||||
|
- path: '(_test\.go|tests/)'
|
||||||
|
linters:
|
||||||
|
- revive
|
||||||
|
text: '^exported:'
|
||||||
|
- path: '(_test\.go|tests/)'
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: 'G(101|401|501)'
|
||||||
|
- path: '(_test\.go|tests/)'
|
||||||
|
linters:
|
||||||
|
- gocyclo
|
||||||
|
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||||
|
- linters:
|
||||||
|
- revive
|
||||||
|
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||||
|
- path: 'internal/conversion/.*\.go'
|
||||||
|
linters:
|
||||||
|
- gocyclo
|
||||||
|
- gocritic
|
||||||
|
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||||
|
linters:
|
||||||
|
- gocyclo
|
||||||
97
backend/Makefile
Normal file
97
backend/Makefile
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
.PHONY: \
|
||||||
|
build run \
|
||||||
|
test test-unit test-integration test-coverage \
|
||||||
|
lint clean \
|
||||||
|
migrate-up migrate-down migrate-status migrate-create \
|
||||||
|
mysql-up mysql-down mysql-test mysql-test-quick
|
||||||
|
|
||||||
|
VERSION := $(shell go run ../versionctl print)
|
||||||
|
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||||
|
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||||
|
|
||||||
|
DB_DRIVER ?= sqlite3
|
||||||
|
DB_DSN ?= $(HOME)/.nex/config.db
|
||||||
|
|
||||||
|
ifeq ($(DB_DRIVER),mysql)
|
||||||
|
GOOSE_DIR := migrations/mysql
|
||||||
|
GOOSE_DRIVER := mysql
|
||||||
|
else ifeq ($(DB_DRIVER),sqlite3)
|
||||||
|
GOOSE_DIR := migrations/sqlite
|
||||||
|
GOOSE_DRIVER := sqlite3
|
||||||
|
else
|
||||||
|
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
|
||||||
|
endif
|
||||||
|
|
||||||
|
build:
|
||||||
|
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
|
||||||
|
|
||||||
|
run:
|
||||||
|
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||||
|
|
||||||
|
test-unit:
|
||||||
|
go test ./internal/... ./pkg/... -v
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
go test ./tests/... -v
|
||||||
|
|
||||||
|
test-coverage:
|
||||||
|
go test ./... -coverprofile=coverage.out
|
||||||
|
go tool cover -html=coverage.out -o coverage.html
|
||||||
|
@printf 'Coverage report generated: backend/coverage.html\n'
|
||||||
|
|
||||||
|
lint:
|
||||||
|
go tool golangci-lint run ./...
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf bin/ coverage.out coverage.html
|
||||||
|
|
||||||
|
migrate-up:
|
||||||
|
@printf 'Running database migration up...\n'
|
||||||
|
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
|
||||||
|
|
||||||
|
migrate-down:
|
||||||
|
@printf 'Running database migration down...\n'
|
||||||
|
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
|
||||||
|
|
||||||
|
migrate-status:
|
||||||
|
@printf 'Checking database migration status...\n'
|
||||||
|
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
|
||||||
|
|
||||||
|
migrate-create:
|
||||||
|
@printf 'Migration name: '; \
|
||||||
|
read name; \
|
||||||
|
goose -dir migrations/sqlite create $$name sql; \
|
||||||
|
goose -dir migrations/mysql create $$name sql
|
||||||
|
|
||||||
|
mysql-up:
|
||||||
|
@printf 'Starting MySQL test container...\n'
|
||||||
|
cd tests/mysql && docker-compose up -d
|
||||||
|
@printf 'Waiting for MySQL to be ready...\n'
|
||||||
|
@for i in $$(seq 1 30); do \
|
||||||
|
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
|
||||||
|
printf 'MySQL is ready\n'; \
|
||||||
|
exit 0; \
|
||||||
|
fi; \
|
||||||
|
printf 'Waiting... (%s/30)\n' $$i; \
|
||||||
|
sleep 1; \
|
||||||
|
done; \
|
||||||
|
printf 'MySQL failed to start\n'; \
|
||||||
|
exit 1
|
||||||
|
|
||||||
|
mysql-down:
|
||||||
|
@printf 'Stopping MySQL test container...\n'
|
||||||
|
cd tests/mysql && docker-compose down -v
|
||||||
|
|
||||||
|
mysql-test:
|
||||||
|
@set -e; \
|
||||||
|
$(MAKE) mysql-up; \
|
||||||
|
trap '$(MAKE) mysql-down' EXIT; \
|
||||||
|
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||||
|
|
||||||
|
mysql-test-quick:
|
||||||
|
@printf 'Running MySQL tests without container management...\n'
|
||||||
|
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||||
@@ -4,29 +4,75 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
|||||||
|
|
||||||
## 功能特性
|
## 功能特性
|
||||||
|
|
||||||
- 支持 OpenAI 协议(`/openai/v1/...`)
|
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`)
|
||||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||||
- 同协议透传(零语义损失、零序列化开销)
|
- 同协议透传(跳过 Canonical 全量转换,保持协议语义)
|
||||||
- 支持流式响应(SSE)
|
- 支持流式响应(SSE)
|
||||||
- 支持 Function Calling / Tools
|
- 支持 Function Calling / Tools
|
||||||
- 支持 Thinking / Reasoning
|
- 支持 Thinking / Reasoning
|
||||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||||
- 多供应商配置和路由
|
- 多供应商配置和路由
|
||||||
- 用量统计
|
- 用量统计
|
||||||
- 结构化日志(zap + lumberjack)
|
- 结构化日志(zap + lumberjack + 模块标识)
|
||||||
- YAML 配置管理
|
- YAML 配置管理
|
||||||
- 请求验证
|
- 请求验证
|
||||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||||
|
|
||||||
|
## 日志规范
|
||||||
|
|
||||||
|
### 模块标识
|
||||||
|
|
||||||
|
每个模块通过依赖注入获取带模块标识的 logger:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
|
||||||
|
return &ProxyHandler{
|
||||||
|
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
输出格式:
|
||||||
|
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
|
||||||
|
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
|
||||||
|
|
||||||
|
### 模块命名规范
|
||||||
|
|
||||||
|
| 模块 | 命名 |
|
||||||
|
|------|------|
|
||||||
|
| ProxyHandler | `handler.proxy` |
|
||||||
|
| ProviderHandler | `handler.provider` |
|
||||||
|
| Provider Client | `provider.client` |
|
||||||
|
| ConversionEngine | `conversion.engine` |
|
||||||
|
| RoutingCache | `service.routing_cache` |
|
||||||
|
| StatsBuffer | `service.stats_buffer` |
|
||||||
|
| Database | `database` |
|
||||||
|
|
||||||
|
### 标准字段
|
||||||
|
|
||||||
|
使用 `pkg/logger/field.go` 中定义的字段构造函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
logger.Debug("请求开始",
|
||||||
|
pkglogger.Method("POST"),
|
||||||
|
pkglogger.Path("/v1/chat"),
|
||||||
|
pkglogger.RequestID("xxx"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### GORM 日志
|
||||||
|
|
||||||
|
GORM 日志自动桥接到 zap,SQL 查询映射到 Debug 级别。
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
- **语言**: Go 1.26+
|
- **语言**: Go 1.26+
|
||||||
- **HTTP 框架**: Gin
|
- **HTTP 框架**: Gin
|
||||||
- **ORM**: GORM
|
- **ORM**: GORM
|
||||||
- **数据库**: SQLite
|
- **数据库**: SQLite / MySQL
|
||||||
- **日志**: zap + lumberjack
|
- **日志**: zap + lumberjack
|
||||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
- **配置**: Viper + pflag(Server 多层配置,Desktop 配置文件快照)
|
||||||
- **验证**: go-playground/validator/v10
|
- **验证**: go-playground/validator/v10
|
||||||
- **迁移**: goose
|
- **迁移**: goose
|
||||||
|
|
||||||
@@ -105,15 +151,23 @@ backend/
|
|||||||
│ │ ├── errors.go
|
│ │ ├── errors.go
|
||||||
│ │ └── wrap.go
|
│ │ └── wrap.go
|
||||||
│ ├── logger/ # 日志系统
|
│ ├── logger/ # 日志系统
|
||||||
│ │ ├── logger.go
|
│ │ ├── logger.go # 核心初始化
|
||||||
│ │ ├── rotate.go
|
│ │ ├── field.go # 标准字段定义
|
||||||
│ │ └── context.go
|
│ │ ├── module.go # 模块日志器
|
||||||
|
│ │ ├── context.go # Context 辅助函数
|
||||||
|
│ │ ├── gorm.go # GORM 适配器
|
||||||
|
│ │ ├── minimal.go # 最小化 logger
|
||||||
|
│ │ └── rotate.go # 日志轮转
|
||||||
│ ├── modelid/ # 统一模型 ID 工具包
|
│ ├── modelid/ # 统一模型 ID 工具包
|
||||||
│ │ ├── model_id.go
|
│ │ ├── model_id.go
|
||||||
│ │ └── model_id_test.go
|
│ │ └── model_id_test.go
|
||||||
│ └── validator/ # 验证器
|
│ └── validator/ # 验证器
|
||||||
│ └── validator.go
|
│ └── validator.go
|
||||||
├── migrations/ # 数据库迁移
|
├── migrations/ # 数据库迁移
|
||||||
|
│ ├── embed.go # go:embed 迁移资源入口
|
||||||
|
│ ├── sqlite/
|
||||||
|
│ │ └── 20260421000001_initial_schema.sql
|
||||||
|
│ └── mysql/
|
||||||
│ └── 20260421000001_initial_schema.sql
|
│ └── 20260421000001_initial_schema.sql
|
||||||
├── tests/ # 集成测试
|
├── tests/ # 集成测试
|
||||||
│ ├── helpers.go # 测试辅助函数
|
│ ├── helpers.go # 测试辅助函数
|
||||||
@@ -170,7 +224,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
|||||||
|
|
||||||
### Smart Passthrough 机制
|
### Smart Passthrough 机制
|
||||||
|
|
||||||
同协议请求走 Smart Passthrough 路径,**零序列化开销**:
|
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
|
||||||
|
|
||||||
```
|
```
|
||||||
1. 检测 clientProtocol == providerProtocol
|
1. 检测 clientProtocol == providerProtocol
|
||||||
@@ -179,12 +233,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
|||||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
|
||||||
|
|
||||||
### 流式转换器层次
|
### 流式转换器层次
|
||||||
|
|
||||||
```
|
```
|
||||||
StreamConverter (接口)
|
StreamConverter (接口)
|
||||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||||
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
|
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
|
||||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -251,6 +307,7 @@ StreamConverter (接口)
|
|||||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||||
| `ENCODING_FAILURE` | 编码失败 |
|
| `ENCODING_FAILURE` | 编码失败 |
|
||||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||||
|
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
|
||||||
|
|
||||||
### AppError 预定义错误
|
### AppError 预定义错误
|
||||||
|
|
||||||
@@ -277,15 +334,18 @@ go mod download
|
|||||||
go run cmd/server/main.go
|
go run cmd/server/main.go
|
||||||
```
|
```
|
||||||
|
|
||||||
服务将在端口 9826 启动。首次启动会自动创建配置文件和运行数据库迁移。
|
服务将在端口 9826 启动。首次启动会自动运行数据库迁移。
|
||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
|
|
||||||
配置支持多种方式:配置文件、环境变量、命令行参数,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
配置方式取决于启动入口:
|
||||||
|
|
||||||
|
- **Server 入口**(`cmd/server`):支持 CLI 参数 > 环境变量 > 配置文件 > 默认值
|
||||||
|
- **Desktop 入口**(`cmd/desktop`):仅支持 `~/.nex/config.yaml` > 默认值,不支持 CLI 参数和 `NEX_*` 环境变量覆盖,修改配置文件后需重启生效
|
||||||
|
|
||||||
### 配置文件
|
### 配置文件
|
||||||
|
|
||||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
|
配置文件位于 `~/.nex/config.yaml`。配置文件不存在时使用默认值,不会自动生成;需要自定义时手动创建该文件:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
server:
|
server:
|
||||||
@@ -294,10 +354,17 @@ server:
|
|||||||
write_timeout: 30s
|
write_timeout: 30s
|
||||||
|
|
||||||
database:
|
database:
|
||||||
path: ~/.nex/config.db
|
driver: sqlite # sqlite 或 mysql
|
||||||
|
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||||
|
# --- MySQL 配置(driver=mysql 时生效)---
|
||||||
|
# host: localhost
|
||||||
|
# port: 3306
|
||||||
|
# user: nex
|
||||||
|
# password: ""
|
||||||
|
# dbname: nex
|
||||||
max_idle_conns: 10
|
max_idle_conns: 10
|
||||||
max_open_conns: 100
|
max_open_conns: 100
|
||||||
conn_max_lifetime: 1h
|
conn_max_lifetime: 1h0m0s
|
||||||
|
|
||||||
log:
|
log:
|
||||||
level: info
|
level: info
|
||||||
@@ -308,19 +375,27 @@ log:
|
|||||||
compress: true
|
compress: true
|
||||||
```
|
```
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量(仅 Server 入口)
|
||||||
|
|
||||||
所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
Server 入口下,所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export NEX_SERVER_PORT=9000
|
export NEX_SERVER_PORT=9000
|
||||||
export NEX_DATABASE_PATH=/data/nex.db
|
export NEX_DATABASE_PATH=/data/nex.db
|
||||||
export NEX_LOG_LEVEL=debug
|
export NEX_LOG_LEVEL=debug
|
||||||
|
|
||||||
|
# MySQL 模式
|
||||||
|
export NEX_DATABASE_DRIVER=mysql
|
||||||
|
export NEX_DATABASE_HOST=db.example.com
|
||||||
|
export NEX_DATABASE_PORT=3306
|
||||||
|
export NEX_DATABASE_USER=nex
|
||||||
|
export NEX_DATABASE_PASSWORD=secret
|
||||||
|
export NEX_DATABASE_DBNAME=nex
|
||||||
```
|
```
|
||||||
|
|
||||||
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。
|
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数(仅 Server 入口)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||||
@@ -332,7 +407,7 @@ export NEX_LOG_LEVEL=debug
|
|||||||
|
|
||||||
```
|
```
|
||||||
服务器: --server-port, --server-read-timeout, --server-write-timeout
|
服务器: --server-port, --server-read-timeout, --server-write-timeout
|
||||||
数据库: --database-path, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
||||||
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
|
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
|
||||||
通用: --config (指定配置文件路径)
|
通用: --config (指定配置文件路径)
|
||||||
```
|
```
|
||||||
@@ -352,36 +427,56 @@ export NEX_LOG_LEVEL=debug
|
|||||||
# Docker 部署
|
# Docker 部署
|
||||||
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
||||||
|
|
||||||
|
# MySQL 模式
|
||||||
|
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
|
||||||
|
|
||||||
# 自定义配置文件
|
# 自定义配置文件
|
||||||
./server --config /path/to/custom.yaml
|
./server --config /path/to/custom.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
数据文件:
|
数据文件:
|
||||||
- `~/.nex/config.yaml` - 配置文件
|
- `~/.nex/config.yaml` - 配置文件
|
||||||
- `~/.nex/config.db` - SQLite 数据库
|
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||||
- `~/.nex/log/` - 日志目录
|
- `~/.nex/log/` - 日志目录
|
||||||
|
|
||||||
|
**MySQL 连接说明**:MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
|
||||||
|
|
||||||
## 测试
|
## 测试
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 运行所有测试
|
# 运行 backend 默认测试
|
||||||
make test
|
make test
|
||||||
|
|
||||||
|
# 分类测试
|
||||||
|
make test-unit
|
||||||
|
make test-integration
|
||||||
|
|
||||||
# 生成覆盖率报告
|
# 生成覆盖率报告
|
||||||
make test-coverage
|
make test-coverage
|
||||||
|
|
||||||
|
# MySQL 专项测试
|
||||||
|
make mysql-up
|
||||||
|
make mysql-down
|
||||||
|
make mysql-test
|
||||||
|
make mysql-test-quick
|
||||||
```
|
```
|
||||||
|
|
||||||
## 数据库迁移
|
## 数据库迁移
|
||||||
|
|
||||||
|
应用启动时使用随二进制打包的迁移资源(`go:embed`)自动执行迁移,server 和 desktop 发布产物均自包含,不依赖源码目录。开发时可继续通过 Makefile goose CLI 操作文件系统中的 `migrations/<dialect>/` 目录,运行时嵌入资源与文件系统目录共享同一批 SQL 文件。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 使用 Makefile
|
# 使用 Makefile
|
||||||
make migrate-up DB_PATH=~/.nex/config.db
|
make migrate-up DB_DSN=~/.nex/config.db
|
||||||
make migrate-down DB_PATH=~/.nex/config.db
|
make migrate-down DB_DSN=~/.nex/config.db
|
||||||
make migrate-status DB_PATH=~/.nex/config.db
|
make migrate-status DB_DSN=~/.nex/config.db
|
||||||
|
|
||||||
# 创建新迁移
|
# 创建新迁移
|
||||||
make migrate-create
|
make migrate-create
|
||||||
|
|
||||||
|
# MySQL 迁移
|
||||||
|
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
|
||||||
|
|
||||||
# 或直接使用 goose
|
# 或直接使用 goose
|
||||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||||
```
|
```
|
||||||
@@ -390,15 +485,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
|||||||
|
|
||||||
### 代理接口
|
### 代理接口
|
||||||
|
|
||||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath,由对应 adapter 识别和组合上游 URL。
|
||||||
|
|
||||||
#### OpenAI 协议
|
#### OpenAI 协议
|
||||||
|
|
||||||
```
|
```
|
||||||
POST /openai/chat/completions
|
POST /openai/v1/chat/completions
|
||||||
GET /openai/models
|
GET /openai/v1/models
|
||||||
POST /openai/embeddings
|
POST /openai/v1/embeddings
|
||||||
POST /openai/rerank
|
POST /openai/v1/rerank
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Anthropic 协议
|
#### Anthropic 协议
|
||||||
@@ -408,10 +503,20 @@ POST /anthropic/v1/messages
|
|||||||
GET /anthropic/v1/models
|
GET /anthropic/v1/models
|
||||||
```
|
```
|
||||||
|
|
||||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough,跳过 Canonical 全量转换。
|
||||||
|
|
||||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||||
|
|
||||||
|
**base_url 约定**:
|
||||||
|
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions` 时,OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`。
|
||||||
|
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`。
|
||||||
|
|
||||||
|
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
|
||||||
|
|
||||||
|
**流式透传边界**:同协议无响应 model 改写时 raw passthrough,保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON,仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
|
||||||
|
|
||||||
|
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`。
|
||||||
|
|
||||||
### 管理接口
|
### 管理接口
|
||||||
|
|
||||||
#### 供应商管理
|
#### 供应商管理
|
||||||
@@ -439,7 +544,7 @@ GET /anthropic/v1/models
|
|||||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||||
|
|
||||||
**对外 URL 格式**:
|
**对外 URL 格式**:
|
||||||
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions`、`/openai/models`、`/openai/embeddings`
|
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions`、`/openai/v1/models`、`/openai/v1/embeddings`
|
||||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||||
|
|
||||||
#### 模型管理
|
#### 模型管理
|
||||||
@@ -481,6 +586,20 @@ GET /anthropic/v1/models
|
|||||||
|
|
||||||
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
||||||
|
|
||||||
|
#### 版本信息
|
||||||
|
|
||||||
|
- `GET /api/version` - 获取后端构建版本信息
|
||||||
|
|
||||||
|
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "0.1.0",
|
||||||
|
"commit": "abc1234",
|
||||||
|
"build_time": "2026-05-05T00:00:00Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
#### 健康检查
|
#### 健康检查
|
||||||
|
|
||||||
- `GET /health` - 返回 `{"status": "ok"}`
|
- `GET /health` - 返回 `{"status": "ok"}`
|
||||||
@@ -488,9 +607,12 @@ GET /anthropic/v1/models
|
|||||||
## 开发
|
## 开发
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make build # 构建
|
make build # 构建 backend/bin/server
|
||||||
|
make run # 运行后端服务
|
||||||
make lint # 代码检查
|
make lint # 代码检查
|
||||||
make deps # 整理依赖
|
make clean # 清理 backend 构建产物
|
||||||
|
go mod tidy # 整理依赖
|
||||||
|
go generate ./... # 刷新 mock 等生成代码
|
||||||
```
|
```
|
||||||
|
|
||||||
环境要求:Go 1.26 或更高版本
|
环境要求:Go 1.26 或更高版本
|
||||||
@@ -539,6 +661,7 @@ err := v.Validate(myStruct)
|
|||||||
|
|
||||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
|
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||||
|
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||||
|
|||||||
43
backend/cmd/desktop/dialog_darwin.go
Normal file
43
backend/cmd/desktop/dialog_darwin.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||||
|
return []promptChannel{
|
||||||
|
{
|
||||||
|
name: "macos-notification",
|
||||||
|
available: func() error {
|
||||||
|
_, err := runner.LookPath("osascript")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
run: func(req promptRequest) error {
|
||||||
|
script := fmt.Sprintf(`display notification "%s" with title "%s" subtitle "%s"`,
|
||||||
|
escapeAppleScript(req.message), escapeAppleScript(req.title), escapeAppleScript(req.subtitle))
|
||||||
|
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "macos-alert",
|
||||||
|
available: func() error {
|
||||||
|
_, err := runner.LookPath("osascript")
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
run: func(req promptRequest) error {
|
||||||
|
script := fmt.Sprintf(`display alert "%s" message "%s" as critical buttons {"OK"} default button "OK"`,
|
||||||
|
escapeAppleScript(req.title), escapeAppleScript(req.message))
|
||||||
|
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeAppleScript(s string) string {
|
||||||
|
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||||
|
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||||
|
return s
|
||||||
|
}
|
||||||
46
backend/cmd/desktop/dialog_darwin_test.go
Normal file
46
backend/cmd/desktop/dialog_darwin_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
//go:build darwin
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDarwinStartupChannelsBuildNotificationAndAlert(t *testing.T) {
|
||||||
|
runner := &fakeCommandRunner{paths: map[string]bool{"osascript": true}}
|
||||||
|
channels := platformStartupChannels(runner)
|
||||||
|
if len(channels) != 2 {
|
||||||
|
t.Fatalf("macOS 应有 notification 和 alert 两级通道,实际: %d", len(channels))
|
||||||
|
}
|
||||||
|
|
||||||
|
req := promptRequest{title: "Nex 启动失败", subtitle: "config", message: "路径 C:\\tmp 包含 \"quote\""}
|
||||||
|
for _, channel := range channels {
|
||||||
|
if err := channel.available(); err != nil {
|
||||||
|
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
|
||||||
|
}
|
||||||
|
if err := channel.run(req); err != nil {
|
||||||
|
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(runner.calls) != 2 {
|
||||||
|
t.Fatalf("应执行两次 osascript,实际: %d", len(runner.calls))
|
||||||
|
}
|
||||||
|
if runner.calls[0].name != "osascript" || runner.calls[0].args[0] != "-e" {
|
||||||
|
t.Fatalf("notification 命令参数错误: %#v", runner.calls[0])
|
||||||
|
}
|
||||||
|
if script := runner.calls[0].args[1]; !strings.Contains(script, "display notification") || !strings.Contains(script, `\\tmp`) || !strings.Contains(script, `\"quote\"`) {
|
||||||
|
t.Fatalf("notification AppleScript 未正确构造或转义: %s", script)
|
||||||
|
}
|
||||||
|
if script := runner.calls[1].args[1]; !strings.Contains(script, "display alert") || !strings.Contains(script, "as critical") {
|
||||||
|
t.Fatalf("alert AppleScript 未使用 critical 告警: %s", script)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEscapeAppleScript(t *testing.T) {
|
||||||
|
got := escapeAppleScript(`C:\tmp "quote"`)
|
||||||
|
if !strings.Contains(got, `C:\\tmp`) || !strings.Contains(got, `\"quote\"`) {
|
||||||
|
t.Fatalf("AppleScript 转义结果错误: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
111
backend/cmd/desktop/dialog_linux.go
Normal file
111
backend/cmd/desktop/dialog_linux.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dialogToolType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
toolNone dialogToolType = iota
|
||||||
|
toolNotifySend
|
||||||
|
toolKdialogPassive
|
||||||
|
toolZenity
|
||||||
|
toolKdialogError
|
||||||
|
toolXmessage
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialogTools map[string]bool
|
||||||
|
dialogToolOnce sync.Once
|
||||||
|
dialogToolNames = []string{"notify-send", "kdialog", "zenity", "xmessage"}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
dialogToolOnce.Do(func() { detectDialogTools(defaultCommandRunner{}) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||||
|
return []promptChannel{
|
||||||
|
linuxCommandChannel("notify-send", toolNotifySend, runner, linuxHasGraphicalSessionAndDBus, func(req promptRequest) []string {
|
||||||
|
return []string{"-u", "critical", "-a", appName, "-i", "nex", req.title, req.message}
|
||||||
|
}),
|
||||||
|
linuxCommandChannel("kdialog", toolKdialogPassive, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||||
|
return []string{"--title", req.title, "--passivepopup", req.message, "10"}
|
||||||
|
}),
|
||||||
|
linuxCommandChannel("zenity", toolZenity, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||||
|
return []string{"--error", fmt.Sprintf("--title=%s", req.title), fmt.Sprintf("--text=%s", req.message)}
|
||||||
|
}),
|
||||||
|
linuxCommandChannel("kdialog", toolKdialogError, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||||
|
return []string{"--title", req.title, "--error", req.message}
|
||||||
|
}),
|
||||||
|
linuxCommandChannel("xmessage", toolXmessage, runner, linuxHasX11Display, func(req promptRequest) []string {
|
||||||
|
return []string{"-center", "-buttons", "OK:0", "-default", "OK", fmt.Sprintf("%s: %s", req.title, req.message)}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectDialogTools(runner commandRunner) {
|
||||||
|
dialogTools = make(map[string]bool, len(dialogToolNames))
|
||||||
|
for _, name := range dialogToolNames {
|
||||||
|
_, err := runner.LookPath(name)
|
||||||
|
dialogTools[name] = err == nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxCommandChannel(name string, typ dialogToolType, runner commandRunner, environmentOK func() error, args func(promptRequest) []string) promptChannel {
|
||||||
|
return promptChannel{
|
||||||
|
name: fmt.Sprintf("linux-%s-%d", name, typ),
|
||||||
|
available: func() error {
|
||||||
|
if err := linuxCommandAvailable(runner, name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return environmentOK()
|
||||||
|
},
|
||||||
|
run: func(req promptRequest) error {
|
||||||
|
return runner.Run(promptCommandTimeout, nil, name, args(req)...)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxCommandAvailable(runner commandRunner, name string) error {
|
||||||
|
if _, ok := runner.(defaultCommandRunner); ok {
|
||||||
|
dialogToolOnce.Do(func() { detectDialogTools(runner) })
|
||||||
|
if dialogTools[name] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s 不可用", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := runner.LookPath(name)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxHasGraphicalSession() error {
|
||||||
|
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
|
||||||
|
return errors.New("缺少图形会话")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxHasGraphicalSessionAndDBus() error {
|
||||||
|
if err := linuxHasGraphicalSession(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if os.Getenv("DBUS_SESSION_BUS_ADDRESS") == "" {
|
||||||
|
return errors.New("缺少 DBus session bus")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func linuxHasX11Display() error {
|
||||||
|
if os.Getenv("DISPLAY") == "" {
|
||||||
|
return errors.New("缺少 X11 DISPLAY")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
61
backend/cmd/desktop/dialog_linux_test.go
Normal file
61
backend/cmd/desktop/dialog_linux_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestLinuxStartupChannelsPriorityAndArguments(t *testing.T) {
|
||||||
|
t.Setenv("DISPLAY", ":0")
|
||||||
|
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "unix:path=/tmp/dbus")
|
||||||
|
runner := &fakeCommandRunner{paths: map[string]bool{
|
||||||
|
"notify-send": true,
|
||||||
|
"kdialog": true,
|
||||||
|
"zenity": true,
|
||||||
|
"xmessage": true,
|
||||||
|
}}
|
||||||
|
|
||||||
|
channels := platformStartupChannels(runner)
|
||||||
|
if len(channels) != 5 {
|
||||||
|
t.Fatalf("Linux 应有 5 个 UI 通道,实际: %d", len(channels))
|
||||||
|
}
|
||||||
|
|
||||||
|
req := promptRequest{title: "Nex 启动失败", message: "端口被占用"}
|
||||||
|
for _, channel := range channels {
|
||||||
|
if err := channel.available(); err != nil {
|
||||||
|
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
|
||||||
|
}
|
||||||
|
if err := channel.run(req); err != nil {
|
||||||
|
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wantNames := []string{"notify-send", "kdialog", "zenity", "kdialog", "xmessage"}
|
||||||
|
for i, want := range wantNames {
|
||||||
|
if got := runner.calls[i].name; got != want {
|
||||||
|
t.Fatalf("第 %d 个命令 = %s, want %s", i, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := runner.calls[0].args; len(got) < 2 || got[0] != "-u" || got[1] != "critical" {
|
||||||
|
t.Fatalf("notify-send 应使用 critical 参数,实际: %#v", got)
|
||||||
|
}
|
||||||
|
if got := runner.calls[1].args; len(got) < 3 || got[2] != "--passivepopup" {
|
||||||
|
t.Fatalf("kdialog 第一跳应使用 passivepopup,实际: %#v", got)
|
||||||
|
}
|
||||||
|
if got := runner.calls[2].args; len(got) < 1 || got[0] != "--error" {
|
||||||
|
t.Fatalf("zenity 应使用 --error,实际: %#v", got)
|
||||||
|
}
|
||||||
|
if got := runner.calls[4].args; len(got) < 1 || got[0] != "-center" {
|
||||||
|
t.Fatalf("xmessage 应居中显示,实际: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinuxNotifySendRequiresDBus(t *testing.T) {
|
||||||
|
t.Setenv("DISPLAY", ":0")
|
||||||
|
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "")
|
||||||
|
runner := &fakeCommandRunner{paths: map[string]bool{"notify-send": true}}
|
||||||
|
|
||||||
|
channels := platformStartupChannels(runner)
|
||||||
|
if err := channels[0].available(); err == nil {
|
||||||
|
t.Fatal("notify-send 缺少 DBus session bus 时应不可用")
|
||||||
|
}
|
||||||
|
}
|
||||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
pkgLogger "nex/backend/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialogLogger() *zap.Logger {
|
||||||
|
if zapLogger != nil {
|
||||||
|
return zapLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkgLogger.NewMinimal()
|
||||||
|
}
|
||||||
133
backend/cmd/desktop/dialog_windows.go
Normal file
133
backend/cmd/desktop/dialog_windows.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
"unicode/utf16"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
mbOK = 0x00000000
|
||||||
|
mbIconError = 0x10
|
||||||
|
mbIconInformation = 0x40
|
||||||
|
mbTaskModal = 0x00002000
|
||||||
|
mbSetForeground = 0x00010000
|
||||||
|
mbTopMost = 0x00040000
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
user32 = syscall.NewLazyDLL("user32.dll")
|
||||||
|
procMessageBoxW = user32.NewProc("MessageBoxW")
|
||||||
|
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
|
||||||
|
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||||
|
return []promptChannel{
|
||||||
|
{
|
||||||
|
name: "windows-toast",
|
||||||
|
available: func() error {
|
||||||
|
_, err := findPowerShell(runner)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
run: func(req promptRequest) error {
|
||||||
|
name, err := findPowerShell(runner)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runner.Run(promptCommandTimeout, []string{
|
||||||
|
"NEX_TOAST_TITLE=" + req.title,
|
||||||
|
"NEX_TOAST_BODY=" + req.message,
|
||||||
|
}, name, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-EncodedCommand", encodePowerShellCommand(windowsToastScript()))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "windows-messagebox",
|
||||||
|
available: func() error {
|
||||||
|
return messageBoxAvailable()
|
||||||
|
},
|
||||||
|
run: func(req promptRequest) error {
|
||||||
|
return messageBox(req.title, req.message, messageBoxStartupFlags())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findPowerShell(runner commandRunner) (string, error) {
|
||||||
|
for _, name := range []string{"powershell.exe", "powershell"} {
|
||||||
|
if _, err := runner.LookPath(name); err == nil {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("PowerShell 不可用")
|
||||||
|
}
|
||||||
|
|
||||||
|
func windowsToastScript() string {
|
||||||
|
return `$ErrorActionPreference = 'Stop'
|
||||||
|
Add-Type -AssemblyName System.Runtime.WindowsRuntime
|
||||||
|
$template = [Windows.UI.Notifications.ToastTemplateType]::ToastText02
|
||||||
|
$xml = [Windows.UI.Notifications.ToastNotificationManager]::GetTemplateContent($template)
|
||||||
|
$texts = $xml.GetElementsByTagName('text')
|
||||||
|
$texts.Item(0).AppendChild($xml.CreateTextNode($env:NEX_TOAST_TITLE)) | Out-Null
|
||||||
|
$texts.Item(1).AppendChild($xml.CreateTextNode($env:NEX_TOAST_BODY)) | Out-Null
|
||||||
|
$toast = [Windows.UI.Notifications.ToastNotification]::new($xml)
|
||||||
|
[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('Nex').Show($toast)`
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodePowerShellCommand(script string) string {
|
||||||
|
encoded := utf16.Encode([]rune(script))
|
||||||
|
buf := make([]byte, 0, len(encoded)*2)
|
||||||
|
for _, value := range encoded {
|
||||||
|
buf = append(buf, byte(value), byte(value>>8))
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageBoxAvailable() error {
|
||||||
|
if _, err := syscall.UTF16PtrFromString("Nex"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := syscall.UTF16PtrFromString("test"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return procMessageBoxW.Find()
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageBoxStartupFlags() uint {
|
||||||
|
return mbOK | mbIconError | mbTaskModal | mbSetForeground | mbTopMost
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageBox(title, message string, flags uint) error {
|
||||||
|
titlePtr, err := syscall.UTF16PtrFromString(title)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
messagePtr, err := syscall.UTF16PtrFromString(message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ret, callErr := callMessageBoxW(
|
||||||
|
0,
|
||||||
|
uintptr(unsafe.Pointer(messagePtr)),
|
||||||
|
uintptr(unsafe.Pointer(titlePtr)),
|
||||||
|
uintptr(flags),
|
||||||
|
)
|
||||||
|
if ret != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
|
||||||
|
return callErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("MessageBoxW 调用失败")
|
||||||
|
}
|
||||||
1
backend/cmd/desktop/icon_windows.rc
Normal file
1
backend/cmd/desktop/icon_windows.rc
Normal file
@@ -0,0 +1 @@
|
|||||||
|
1 ICON "../../../assets/icon.ico"
|
||||||
@@ -2,9 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,31 +12,28 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"nex/embedfs"
|
||||||
"github.com/getlantern/systray"
|
|
||||||
"github.com/gofrs/flock"
|
|
||||||
"github.com/pressly/goose/v3"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
"nex/backend/internal/config"
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/anthropic"
|
"nex/backend/internal/conversion/anthropic"
|
||||||
"nex/backend/internal/conversion/openai"
|
"nex/backend/internal/conversion/openai"
|
||||||
|
"nex/backend/internal/database"
|
||||||
"nex/backend/internal/handler"
|
"nex/backend/internal/handler"
|
||||||
"nex/backend/internal/handler/middleware"
|
"nex/backend/internal/handler/middleware"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
"nex/backend/internal/repository"
|
"nex/backend/internal/repository"
|
||||||
"nex/backend/internal/service"
|
"nex/backend/internal/service"
|
||||||
pkgLogger "nex/backend/pkg/logger"
|
"nex/backend/pkg/buildinfo"
|
||||||
|
|
||||||
"nex/embedfs"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gofrs/flock"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
pkgLogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -44,30 +41,79 @@ var (
|
|||||||
zapLogger *zap.Logger
|
zapLogger *zap.Logger
|
||||||
shutdownCtx context.Context
|
shutdownCtx context.Context
|
||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
|
desktopHooks = defaultDesktopRuntimeHooks()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type singletonLocker interface {
|
||||||
|
Lock() error
|
||||||
|
Unlock() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type desktopRuntimeHooks struct {
|
||||||
|
loadConfig func() (*config.Config, config.ConfigMetadata, error)
|
||||||
|
newLock func(string) singletonLocker
|
||||||
|
listen func(int) (net.Listener, error)
|
||||||
|
upgradeLogger func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error)
|
||||||
|
initDB func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error)
|
||||||
|
closeDB func(*gorm.DB)
|
||||||
|
registerAdapters func(conversion.AdapterRegistry) error
|
||||||
|
setupStaticFiles func(*gin.Engine) error
|
||||||
|
startServer func(*http.Server, net.Listener, chan<- error, *zap.Logger)
|
||||||
|
setupSystray func(int, <-chan error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultDesktopRuntimeHooks() desktopRuntimeHooks {
|
||||||
|
return desktopRuntimeHooks{
|
||||||
|
loadConfig: config.LoadDesktopConfigWithMetadata,
|
||||||
|
newLock: func(lockPath string) singletonLocker { return NewSingletonLock(lockPath) },
|
||||||
|
listen: listenDesktopPort,
|
||||||
|
upgradeLogger: pkgLogger.Upgrade,
|
||||||
|
initDB: database.Init,
|
||||||
|
closeDB: database.Close,
|
||||||
|
registerAdapters: registerDesktopAdapters,
|
||||||
|
setupStaticFiles: setupStaticFiles,
|
||||||
|
startServer: startDesktopServer,
|
||||||
|
setupSystray: setupSystray,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
port := 9826
|
minimalLogger := pkgLogger.NewMinimal()
|
||||||
|
if err := runDesktop(minimalLogger); err != nil {
|
||||||
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
reportStartupFailure(err, dialogLogger())
|
||||||
if err := singleLock.Lock(); err != nil {
|
|
||||||
showError("Nex Gateway", "已有 Nex 实例运行")
|
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
defer singleLock.Unlock()
|
|
||||||
|
|
||||||
if err := checkPortAvailable(port); err != nil {
|
|
||||||
showError("Nex Gateway", err.Error())
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := config.LoadConfig()
|
func runDesktop(minimalLogger *zap.Logger) error {
|
||||||
|
if minimalLogger == nil {
|
||||||
|
minimalLogger = pkgLogger.NewMinimal()
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, cfgMeta, err := desktopHooks.loadConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err))
|
return newStartupError(phaseConfig, desktopConfigErrorMessage(getDesktopConfigPath(), err), err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
zapLogger, err = pkgLogger.New(pkgLogger.Config{
|
port := cfg.Server.Port
|
||||||
|
|
||||||
|
singleLock := desktopHooks.newLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||||
|
if err := singleLock.Lock(); err != nil {
|
||||||
|
return newStartupError(phaseSingleton, "已有 Nex 实例运行", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := singleLock.Unlock(); err != nil {
|
||||||
|
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
listener, err := desktopHooks.listen(port)
|
||||||
|
if err != nil {
|
||||||
|
return newStartupError(phasePort, desktopPortUnavailableMessage(port), err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
zapLogger, err = desktopHooks.upgradeLogger(minimalLogger, pkgLogger.Config{
|
||||||
Level: cfg.Log.Level,
|
Level: cfg.Log.Level,
|
||||||
Path: cfg.Log.Path,
|
Path: cfg.Log.Path,
|
||||||
MaxSize: cfg.Log.MaxSize,
|
MaxSize: cfg.Log.MaxSize,
|
||||||
@@ -76,17 +122,27 @@ func main() {
|
|||||||
Compress: cfg.Log.Compress,
|
Compress: cfg.Log.Compress,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err))
|
return newStartupError(phaseLogger, fmt.Sprintf("初始化日志失败\n\n日志目录: %s\n\n请检查目录权限或磁盘空间", cfg.Log.Path), err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
defer zapLogger.Sync()
|
defer func() {
|
||||||
|
if err := zapLogger.Sync(); err != nil {
|
||||||
|
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
db, err := initDatabase(cfg)
|
cfg.PrintSummary(zapLogger)
|
||||||
|
|
||||||
|
db, err := desktopHooks.initDB(&cfg.Database, zapLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
|
phase := phaseDatabase
|
||||||
os.Exit(1)
|
message := fmt.Sprintf("数据库初始化失败\n\n请检查数据库配置、文件权限或连接状态\n\n%v", err)
|
||||||
|
if errors.Is(err, database.ErrMigration) {
|
||||||
|
phase = phaseMigration
|
||||||
|
message = fmt.Sprintf("数据库迁移失败\n\n请查看日志或检查数据库迁移权限\n\n%v", err)
|
||||||
}
|
}
|
||||||
defer closeDB(db)
|
return newStartupError(phase, message, err)
|
||||||
|
}
|
||||||
|
defer desktopHooks.closeDB(db)
|
||||||
|
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
@@ -109,20 +165,19 @@ func main() {
|
|||||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||||
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
if err := desktopHooks.registerAdapters(registry); err != nil {
|
||||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
return newStartupError(phaseAdapter, startupInternalErrorMessage(), err)
|
||||||
}
|
|
||||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
|
||||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
|
||||||
}
|
}
|
||||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||||
|
|
||||||
providerClient := provider.NewClient()
|
providerClient := provider.NewClient(zapLogger)
|
||||||
|
|
||||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||||
providerHandler := handler.NewProviderHandler(providerService)
|
providerHandler := handler.NewProviderHandler(providerService)
|
||||||
modelHandler := handler.NewModelHandler(modelService)
|
modelHandler := handler.NewModelHandler(modelService)
|
||||||
statsHandler := handler.NewStatsHandler(statsService)
|
statsHandler := handler.NewStatsHandler(statsService)
|
||||||
|
versionHandler := handler.NewVersionHandler()
|
||||||
|
settingsHandler := handler.NewSettingsHandler(cfg, "desktop", true, cfgMeta.ConfigPath)
|
||||||
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
@@ -132,107 +187,65 @@ func main() {
|
|||||||
r.Use(middleware.Logging(zapLogger))
|
r.Use(middleware.Logging(zapLogger))
|
||||||
r.Use(middleware.CORS())
|
r.Use(middleware.CORS())
|
||||||
|
|
||||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
|
||||||
setupStaticFiles(r)
|
if err := desktopHooks.setupStaticFiles(r); err != nil {
|
||||||
|
return newStartupError(phaseStaticResource, startupInternalErrorMessage(), err)
|
||||||
|
}
|
||||||
|
|
||||||
server = &http.Server{
|
server = &http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", port),
|
Addr: desktopListenAddr(port),
|
||||||
Handler: r,
|
Handler: r,
|
||||||
ReadTimeout: cfg.Server.ReadTimeout,
|
ReadTimeout: cfg.Server.ReadTimeout,
|
||||||
WriteTimeout: cfg.Server.WriteTimeout,
|
WriteTimeout: cfg.Server.WriteTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||||
|
defer doShutdown()
|
||||||
|
|
||||||
go func() {
|
serverErrCh := make(chan error, 1)
|
||||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
|
desktopHooks.startServer(server, listener, serverErrCh, zapLogger)
|
||||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
select {
|
||||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
case err := <-serverErrCh:
|
||||||
}
|
return newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||||
}()
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
|
||||||
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
setupSystray(port)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
if err := desktopHooks.setupSystray(port, serverErrCh); err != nil {
|
||||||
dbDir := filepath.Dir(cfg.Database.Path)
|
|
||||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
|
||||||
Logger: logger.Default.LogMode(logger.Info),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := runMigrations(db); err != nil {
|
|
||||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
|
||||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlDB, err := db.DB()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
|
||||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
|
||||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runMigrations(db *gorm.DB) error {
|
|
||||||
sqlDB, err := db.DB()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
migrationsDir := getMigrationsDir()
|
|
||||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
goose.SetDialect("sqlite3")
|
|
||||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-serverErrCh:
|
||||||
|
return newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMigrationsDir() string {
|
|
||||||
_, filename, _, ok := runtime.Caller(0)
|
|
||||||
if ok {
|
|
||||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
|
||||||
if abs, err := filepath.Abs(dir); err == nil {
|
|
||||||
return abs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "./migrations"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeDB(db *gorm.DB) {
|
func registerDesktopAdapters(registry conversion.AdapterRegistry) error {
|
||||||
sqlDB, err := db.DB()
|
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
sqlDB.Close()
|
return registry.Register(anthropic.NewAdapter())
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
func startDesktopServer(server *http.Server, listener net.Listener, serverErrCh chan<- error, logger *zap.Logger) {
|
||||||
r.Any("/v1/*path", proxyHandler.HandleProxy)
|
go func() {
|
||||||
|
logger.Info("AI Gateway 启动",
|
||||||
|
zap.String("addr", server.Addr),
|
||||||
|
zap.String("version", buildinfo.Version()),
|
||||||
|
zap.String("commit", buildinfo.Commit()),
|
||||||
|
zap.String("build_time", buildinfo.BuildTime()))
|
||||||
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
serverErrCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
|
||||||
|
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
|
||||||
|
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
|
||||||
|
r.GET("/api/version", versionHandler.GetVersion)
|
||||||
|
|
||||||
providers := r.Group("/api/providers")
|
providers := r.Group("/api/providers")
|
||||||
{
|
{
|
||||||
@@ -258,17 +271,38 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
|||||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
settings := r.Group("/api/settings")
|
||||||
|
{
|
||||||
|
settings.GET("/startup", settingsHandler.GetStartupSettings)
|
||||||
|
settings.PUT("/startup", settingsHandler.SaveStartupSettings)
|
||||||
|
}
|
||||||
|
|
||||||
r.GET("/health", func(c *gin.Context) {
|
r.GET("/health", func(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupStaticFiles(r *gin.Engine) {
|
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
return func(c *gin.Context) {
|
||||||
if err != nil {
|
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
|
||||||
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
|
next(c)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupStaticFiles(r *gin.Engine) error {
|
||||||
|
distFS, err := frontendDistFS()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
setupStaticFilesWithFS(r, distFS)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func frontendDistFS() (fs.FS, error) {
|
||||||
|
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||||
getContentType := func(path string) string {
|
getContentType := func(path string) string {
|
||||||
if strings.HasSuffix(path, ".js") {
|
if strings.HasSuffix(path, ".js") {
|
||||||
return "application/javascript"
|
return "application/javascript"
|
||||||
@@ -301,20 +335,23 @@ func setupStaticFiles(r *gin.Engine) {
|
|||||||
c.Data(200, getContentType(filepath), data)
|
c.Data(200, getContentType(filepath), data)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
r.GET("/icon.png", func(c *gin.Context) {
|
||||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
data, err := fs.ReadFile(distFS, "icon.png")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Status(404)
|
c.Status(404)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Data(200, "image/svg+xml", data)
|
c.Data(200, "image/png", data)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.NoRoute(func(c *gin.Context) {
|
r.NoRoute(func(c *gin.Context) {
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/openai/") ||
|
||||||
|
strings.HasPrefix(path, "/anthropic/") ||
|
||||||
|
path == "/openai" ||
|
||||||
|
path == "/anthropic" ||
|
||||||
strings.HasPrefix(path, "/health") {
|
strings.HasPrefix(path, "/health") {
|
||||||
c.JSON(404, gin.H{"error": "not found"})
|
c.JSON(404, gin.H{"error": "not found"})
|
||||||
return
|
return
|
||||||
@@ -329,50 +366,6 @@ func setupStaticFiles(r *gin.Engine) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupSystray(port int) {
|
|
||||||
systray.Run(func() {
|
|
||||||
var icon []byte
|
|
||||||
var err error
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
|
|
||||||
} else {
|
|
||||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
|
|
||||||
}
|
|
||||||
systray.SetIcon(icon)
|
|
||||||
systray.SetTitle("Nex Gateway")
|
|
||||||
systray.SetTooltip("AI Gateway")
|
|
||||||
|
|
||||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
|
||||||
systray.AddSeparator()
|
|
||||||
mStatus := systray.AddMenuItem("状态: 运行中", "")
|
|
||||||
mStatus.Disable()
|
|
||||||
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
|
|
||||||
mPort.Disable()
|
|
||||||
systray.AddSeparator()
|
|
||||||
mAbout := systray.AddMenuItem("关于", "")
|
|
||||||
systray.AddSeparator()
|
|
||||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-mOpen.ClickedCh:
|
|
||||||
openBrowser(fmt.Sprintf("http://localhost:%d", port))
|
|
||||||
case <-mAbout.ClickedCh:
|
|
||||||
showAbout()
|
|
||||||
case <-mQuit.ClickedCh:
|
|
||||||
doShutdown()
|
|
||||||
systray.Quit()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func doShutdown() {
|
func doShutdown() {
|
||||||
if zapLogger != nil {
|
if zapLogger != nil {
|
||||||
zapLogger.Info("正在关闭服务器...")
|
zapLogger.Info("正在关闭服务器...")
|
||||||
@@ -381,7 +374,9 @@ func doShutdown() {
|
|||||||
if server != nil {
|
if server != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
server.Shutdown(ctx)
|
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||||
|
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shutdownCancel != nil {
|
if shutdownCancel != nil {
|
||||||
@@ -389,13 +384,36 @@ func doShutdown() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkPortAvailable(port int) error {
|
func getDesktopConfigPath() string {
|
||||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
configDir, err := config.GetConfigDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
return "~/.nex/config.yaml"
|
||||||
}
|
}
|
||||||
ln.Close()
|
return filepath.Join(configDir, "config.yaml")
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
func desktopConfigErrorMessage(configPath string, err error) string {
|
||||||
|
return fmt.Sprintf("加载配置失败\n\n配置文件: %s\n\n%v", configPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func desktopListenAddr(port int) string {
|
||||||
|
return fmt.Sprintf(":%d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func desktopURL(port int) string {
|
||||||
|
return fmt.Sprintf("http://localhost:%d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func desktopPortMenuTitle(port int) string {
|
||||||
|
return fmt.Sprintf("端口: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenDesktopPort(port int) (net.Listener, error) {
|
||||||
|
return net.Listen("tcp", desktopListenAddr(port))
|
||||||
|
}
|
||||||
|
|
||||||
|
func desktopPortUnavailableMessage(port int) string {
|
||||||
|
return fmt.Sprintf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SingletonLock struct {
|
type SingletonLock struct {
|
||||||
@@ -419,8 +437,8 @@ func (s *SingletonLock) Lock() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SingletonLock) Unlock() {
|
func (s *SingletonLock) Unlock() error {
|
||||||
s.flock.Unlock()
|
return s.flock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func openBrowser(url string) error {
|
func openBrowser(url string) error {
|
||||||
@@ -447,49 +465,3 @@ func openBrowser(url string) error {
|
|||||||
|
|
||||||
return cmd.Start()
|
return cmd.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
func showError(title, message string) {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
|
|
||||||
exec.Command("osascript", "-e", script).Run()
|
|
||||||
case "windows":
|
|
||||||
messageBox(title, message, MB_ICONERROR)
|
|
||||||
case "linux":
|
|
||||||
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func showAbout() {
|
|
||||||
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "darwin":
|
|
||||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
|
|
||||||
exec.Command("osascript", "-e", script).Run()
|
|
||||||
case "windows":
|
|
||||||
messageBox("关于 Nex Gateway", message, MB_ICONINFORMATION)
|
|
||||||
case "linux":
|
|
||||||
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
MB_ICONERROR = 0x10
|
|
||||||
MB_ICONINFORMATION = 0x40
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
user32 = syscall.NewLazyDLL("user32.dll")
|
|
||||||
procMessageBoxW = user32.NewProc("MessageBoxW")
|
|
||||||
)
|
|
||||||
|
|
||||||
func messageBox(title, message string, flags uint) {
|
|
||||||
titlePtr, _ := syscall.UTF16PtrFromString(title)
|
|
||||||
messagePtr, _ := syscall.UTF16PtrFromString(message)
|
|
||||||
procMessageBoxW.Call(
|
|
||||||
0,
|
|
||||||
uintptr(unsafe.Pointer(messagePtr)),
|
|
||||||
uintptr(unsafe.Pointer(titlePtr)),
|
|
||||||
uintptr(flags),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,30 +1,106 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"errors"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMessageBoxW_WindowsOnly(t *testing.T) {
|
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
|
||||||
if runtime.GOOS != "windows" {
|
t.Helper()
|
||||||
t.Skip("MessageBoxW 仅在 Windows 上测试")
|
|
||||||
|
old := callMessageBoxW
|
||||||
|
callMessageBoxW = fn
|
||||||
|
t.Cleanup(func() {
|
||||||
|
callMessageBoxW = old
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
messageBox("测试标题", "测试消息", MB_ICONINFORMATION)
|
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
|
||||||
|
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("包含 NUL 字符时应该返回错误")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
|
||||||
|
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||||
|
return 1, syscall.Errno(123)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
|
||||||
|
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
|
||||||
|
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||||
|
return 0, syscall.Errno(5)
|
||||||
|
})
|
||||||
|
|
||||||
|
err := messageBox("测试标题", "测试消息", mbIconInformation)
|
||||||
|
if !errors.Is(err, syscall.Errno(5)) {
|
||||||
|
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShowError_WindowsBranch(t *testing.T) {
|
func TestShowError_WindowsBranch(t *testing.T) {
|
||||||
if runtime.GOOS != "windows" {
|
old := buildPromptChannels
|
||||||
t.Skip("Windows 原生对话框测试仅在 Windows 上运行")
|
buildPromptChannels = func(commandRunner) []promptChannel {
|
||||||
|
return []promptChannel{{
|
||||||
|
name: "fake-failed-channel",
|
||||||
|
available: func() error { return nil },
|
||||||
|
run: func(promptRequest) error { return syscall.Errno(5) },
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
|
t.Cleanup(func() { buildPromptChannels = old })
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
showError("测试错误", "这是一条测试错误消息")
|
showError("测试错误", "这是一条测试错误消息")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShowAbout_WindowsBranch(t *testing.T) {
|
func TestMessageBoxW_WindowsOnly_StartupFlags(t *testing.T) {
|
||||||
if runtime.GOOS != "windows" {
|
var gotFlags uintptr
|
||||||
t.Skip("Windows 原生对话框测试仅在 Windows 上运行")
|
withMessageBoxW(t, func(_, _, _, flags uintptr) (uintptr, error) {
|
||||||
|
gotFlags = flags
|
||||||
|
return 1, syscall.Errno(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := messageBox("测试标题", "测试消息", messageBoxStartupFlags()); err != nil {
|
||||||
|
t.Fatalf("MessageBoxW 应成功: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
showAbout()
|
for _, flag := range []uint{mbIconError, mbTaskModal, mbSetForeground, mbTopMost} {
|
||||||
|
if gotFlags&uintptr(flag) == 0 {
|
||||||
|
t.Fatalf("startup flags 缺少 0x%x,实际: 0x%x", flag, gotFlags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWindowsStartupChannelsUseToastBeforeMessageBox(t *testing.T) {
|
||||||
|
runner := &fakeCommandRunner{paths: map[string]bool{"powershell.exe": true}}
|
||||||
|
channels := platformStartupChannels(runner)
|
||||||
|
if len(channels) != 2 {
|
||||||
|
t.Fatalf("Windows 应有 Toast 和 MessageBox 两级通道,实际: %d", len(channels))
|
||||||
|
}
|
||||||
|
|
||||||
|
if channels[0].name != "windows-toast" || channels[1].name != "windows-messagebox" {
|
||||||
|
t.Fatalf("Windows 通道顺序错误: %s, %s", channels[0].name, channels[1].name)
|
||||||
|
}
|
||||||
|
if err := channels[0].available(); err != nil {
|
||||||
|
t.Fatalf("PowerShell 存在时 Toast 通道应可用: %v", err)
|
||||||
|
}
|
||||||
|
if err := channels[0].run(promptRequest{title: "Nex 启动失败", message: "端口被占用"}); err != nil {
|
||||||
|
t.Fatalf("Toast fake runner 应执行成功: %v", err)
|
||||||
|
}
|
||||||
|
if len(runner.calls) != 1 || runner.calls[0].name != "powershell.exe" {
|
||||||
|
t.Fatalf("Toast 应调用 powershell.exe,实际: %#v", runner.calls)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
9
backend/cmd/desktop/metadata.go
Normal file
9
backend/cmd/desktop/metadata.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
const (
|
||||||
|
appName = "Nex"
|
||||||
|
appTooltip = appName
|
||||||
|
appDescription = "AI Gateway - 统一的大模型 API 网关"
|
||||||
|
// #nosec G101 -- 项目官网地址不是凭据
|
||||||
|
appWebsite = "https://github.com/nex/gateway"
|
||||||
|
)
|
||||||
13
backend/cmd/desktop/metadata_test.go
Normal file
13
backend/cmd/desktop/metadata_test.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDesktopMetadata(t *testing.T) {
|
||||||
|
if appName != "Nex" {
|
||||||
|
t.Fatalf("appName = %q, want %q", appName, "Nex")
|
||||||
|
}
|
||||||
|
|
||||||
|
if appTooltip != appName {
|
||||||
|
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
|
||||||
|
}
|
||||||
|
}
|
||||||
43
backend/cmd/desktop/migration_test.go
Normal file
43
backend/cmd/desktop/migration_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
"nex/backend/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDesktop_InitMigrationsWithoutSourceTree(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
origDir, err := os.Getwd()
|
||||||
|
if err == nil {
|
||||||
|
defer func() {
|
||||||
|
if chdirErr := os.Chdir(origDir); chdirErr != nil {
|
||||||
|
t.Logf("无法恢复工作目录: %v", chdirErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
if chdirErr := os.Chdir(tmpDir); chdirErr != nil {
|
||||||
|
t.Skipf("无法切换到临时目录: %v", chdirErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
|
Path: filepath.Join(tmpDir, "nex-test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
db, err := database.Init(cfg, zapLogger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("在无源码目录环境下数据库初始化应成功,但返回错误: %v", err)
|
||||||
|
}
|
||||||
|
database.Close(db)
|
||||||
|
}
|
||||||
@@ -1,69 +1,69 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCheckPortAvailable(t *testing.T) {
|
func TestListenDesktopPortReturnsReusableListener(t *testing.T) {
|
||||||
port := 19826
|
listener, err := listenDesktopPort(0)
|
||||||
|
|
||||||
err := checkPortAvailable(port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("端口 %d 应该可用: %v", port, err)
|
t.Fatalf("listener-first 应直接获取配置端口 listener: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("端口可用测试通过")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckPortOccupied(t *testing.T) {
|
|
||||||
port := 19827
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", ":19827")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("无法启动测试服务器: %v", err)
|
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
|
|
||||||
|
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||||
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := listener.Accept()
|
defer close(done)
|
||||||
if err == nil {
|
err := server.Serve(listener)
|
||||||
conn.Close()
|
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Errorf("使用同一个 listener 启动 server 失败: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
if err := server.Close(); err != nil {
|
||||||
|
t.Fatalf("关闭测试 server 失败: %v", err)
|
||||||
err = checkPortAvailable(port)
|
}
|
||||||
if err == nil {
|
<-done
|
||||||
t.Fatal("端口被占用时应该返回错误")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("端口占用检测测试通过")
|
func TestGetDesktopConfigPath(t *testing.T) {
|
||||||
|
path := getDesktopConfigPath()
|
||||||
|
if path == "" {
|
||||||
|
t.Fatal("getDesktopConfigPath 应返回非空路径")
|
||||||
|
}
|
||||||
|
if !strings.Contains(path, "config.yaml") {
|
||||||
|
t.Fatalf("路径应包含 config.yaml,实际: %s", path)
|
||||||
|
}
|
||||||
|
t.Log("getDesktopConfigPath 测试通过")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
func TestDesktopConfiguredPortHelpers(t *testing.T) {
|
||||||
port := 19828
|
port := 19830
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", ":19828")
|
if got := desktopListenAddr(port); got != ":19830" {
|
||||||
if err != nil {
|
t.Fatalf("HTTP 监听地址应使用配置端口,实际: %s", got)
|
||||||
t.Fatalf("无法启动测试服务器: %v", err)
|
}
|
||||||
|
if got := desktopURL(port); got != "http://localhost:19830" {
|
||||||
|
t.Fatalf("浏览器 URL 应使用配置端口,实际: %s", got)
|
||||||
|
}
|
||||||
|
if got := desktopPortMenuTitle(port); got != "端口: 19830" {
|
||||||
|
t.Fatalf("托盘端口显示应使用配置端口,实际: %s", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server := &http.Server{}
|
func TestDesktopConfigErrorMessageContainsPathAndReason(t *testing.T) {
|
||||||
go server.Serve(listener)
|
msg := desktopConfigErrorMessage("/tmp/nex/config.yaml", errors.New("yaml parse failed"))
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
if !strings.Contains(msg, "/tmp/nex/config.yaml") {
|
||||||
|
t.Fatalf("配置错误提示应包含配置路径,实际: %s", msg)
|
||||||
listener.Close()
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
if !strings.Contains(msg, "yaml parse failed") {
|
||||||
|
t.Fatalf("配置错误提示应包含失败原因,实际: %s", msg)
|
||||||
err = checkPortAvailable(port)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("端口关闭后应该可用: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("端口关闭后可用测试通过")
|
|
||||||
}
|
}
|
||||||
|
|||||||
121
backend/cmd/desktop/reporter.go
Normal file
121
backend/cmd/desktop/reporter.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const promptCommandTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
type promptRequest struct {
|
||||||
|
title string
|
||||||
|
message string
|
||||||
|
subtitle string
|
||||||
|
}
|
||||||
|
|
||||||
|
type promptChannel struct {
|
||||||
|
name string
|
||||||
|
available func() error
|
||||||
|
run func(promptRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandRunner interface {
|
||||||
|
LookPath(file string) (string, error)
|
||||||
|
Run(timeout time.Duration, env []string, name string, args ...string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultCommandRunner struct{}
|
||||||
|
|
||||||
|
var buildPromptChannels = platformStartupChannels
|
||||||
|
|
||||||
|
func (defaultCommandRunner) LookPath(file string) (string, error) {
|
||||||
|
return exec.LookPath(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (defaultCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, name, args...)
|
||||||
|
if len(env) > 0 {
|
||||||
|
cmd.Env = append(os.Environ(), env...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func showError(title, message string) {
|
||||||
|
reportPrompt(promptRequest{title: title, message: message}, os.Stderr, dialogLogger())
|
||||||
|
}
|
||||||
|
|
||||||
|
func reportStartupFailure(err error, logger *zap.Logger) {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var startupErr *startupError
|
||||||
|
if !errors.As(err, &startupErr) {
|
||||||
|
startupErr = newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
logger = dialogLogger()
|
||||||
|
}
|
||||||
|
logger.Error("desktop 启动失败",
|
||||||
|
zap.String("phase", startupErr.Phase()),
|
||||||
|
zap.Error(startupErr))
|
||||||
|
|
||||||
|
reportPrompt(promptRequest{
|
||||||
|
title: startupTitle(),
|
||||||
|
message: startupErr.UserMessage(),
|
||||||
|
subtitle: startupErr.Phase(),
|
||||||
|
}, os.Stderr, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reportPrompt(req promptRequest, fallback io.Writer, logger *zap.Logger) {
|
||||||
|
runPromptPipeline(req, buildPromptChannels(defaultCommandRunner{}), fallback, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runPromptPipeline(req promptRequest, channels []promptChannel, fallback io.Writer, logger *zap.Logger) {
|
||||||
|
if logger == nil {
|
||||||
|
logger = dialogLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.available != nil {
|
||||||
|
if err := channel.available(); err != nil {
|
||||||
|
logger.Warn("提示通道不可用", zap.String("channel", channel.name), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := channel.run(req); err != nil {
|
||||||
|
logger.Warn("提示通道执行失败", zap.String("channel", channel.name), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writePromptFallback(fallback, req.title, req.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePromptFallback(w io.Writer, title, message string) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(w, "错误: "+title+": "+message+"\n"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
140
backend/cmd/desktop/reporter_test.go
Normal file
140
backend/cmd/desktop/reporter_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type commandCall struct {
|
||||||
|
timeout time.Duration
|
||||||
|
env []string
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCommandRunner struct {
|
||||||
|
paths map[string]bool
|
||||||
|
runErrs map[string]error
|
||||||
|
calls []commandCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeCommandRunner) LookPath(file string) (string, error) {
|
||||||
|
if r.paths[file] {
|
||||||
|
return "/usr/bin/" + file, nil
|
||||||
|
}
|
||||||
|
return "", exec.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *fakeCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
|
||||||
|
r.calls = append(r.calls, commandCall{
|
||||||
|
timeout: timeout,
|
||||||
|
env: append([]string(nil), env...),
|
||||||
|
name: name,
|
||||||
|
args: append([]string(nil), args...),
|
||||||
|
})
|
||||||
|
if err := r.runErrs[name]; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunPromptPipelineFallbackOrder(t *testing.T) {
|
||||||
|
var calls []string
|
||||||
|
channels := []promptChannel{
|
||||||
|
{
|
||||||
|
name: "unavailable",
|
||||||
|
available: func() error {
|
||||||
|
calls = append(calls, "available-1")
|
||||||
|
return errors.New("missing")
|
||||||
|
},
|
||||||
|
run: func(promptRequest) error {
|
||||||
|
calls = append(calls, "run-1")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failed",
|
||||||
|
available: func() error {
|
||||||
|
calls = append(calls, "available-2")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
run: func(promptRequest) error {
|
||||||
|
calls = append(calls, "run-2")
|
||||||
|
return errors.New("failed")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
available: func() error {
|
||||||
|
calls = append(calls, "available-3")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
run: func(promptRequest) error {
|
||||||
|
calls = append(calls, "run-3")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallback bytes.Buffer
|
||||||
|
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "启动失败"}, channels, &fallback, zap.NewNop())
|
||||||
|
|
||||||
|
want := []string{"available-1", "available-2", "run-2", "available-3", "run-3"}
|
||||||
|
if fmt.Sprint(calls) != fmt.Sprint(want) {
|
||||||
|
t.Fatalf("调用顺序 = %v, want %v", calls, want)
|
||||||
|
}
|
||||||
|
if fallback.Len() != 0 {
|
||||||
|
t.Fatalf("成功通道后不应写入 fallback,实际: %s", fallback.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunPromptPipelineWritesFallback(t *testing.T) {
|
||||||
|
channels := []promptChannel{
|
||||||
|
{
|
||||||
|
name: "unavailable",
|
||||||
|
available: func() error { return errors.New("missing") },
|
||||||
|
run: func(promptRequest) error { return nil },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallback bytes.Buffer
|
||||||
|
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "端口被占用"}, channels, &fallback, zap.NewNop())
|
||||||
|
|
||||||
|
want := "错误: Nex 启动失败: 端口被占用\n"
|
||||||
|
if fallback.String() != want {
|
||||||
|
t.Fatalf("fallback = %q, want %q", fallback.String(), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReportStartupFailureLogsRedactedError(t *testing.T) {
|
||||||
|
old := buildPromptChannels
|
||||||
|
buildPromptChannels = func(commandRunner) []promptChannel {
|
||||||
|
return []promptChannel{{name: "fake-success", run: func(promptRequest) error { return nil }}}
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { buildPromptChannels = old })
|
||||||
|
|
||||||
|
core, logs := observer.New(zap.ErrorLevel)
|
||||||
|
logger := zap.New(core)
|
||||||
|
err := errors.New("数据库连接失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test")
|
||||||
|
|
||||||
|
reportStartupFailure(err, logger)
|
||||||
|
|
||||||
|
entries := logs.All()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("应记录 1 条错误日志,实际: %d", len(entries))
|
||||||
|
}
|
||||||
|
fields := fmt.Sprint(entries[0].ContextMap())
|
||||||
|
for _, secret := range []string{"secret", "sk-test"} {
|
||||||
|
if strings.Contains(fields, secret) {
|
||||||
|
t.Fatalf("启动失败日志不应包含敏感信息 %q,实际: %s", secret, fields)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
44
backend/cmd/desktop/routes_test.go
Normal file
44
backend/cmd/desktop/routes_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"nex/backend/internal/handler"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler(), handler.NewSettingsHandler(nil, "desktop", true, ""))
|
||||||
|
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||||
|
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
|
||||||
|
t.Fatalf("版本接口不应返回 SPA fallback HTML")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||||
|
t.Fatalf("解析响应失败: %v", err)
|
||||||
|
}
|
||||||
|
for _, key := range []string{"version", "commit", "build_time"} {
|
||||||
|
if result[key] == "" {
|
||||||
|
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
332
backend/cmd/desktop/run_desktop_test.go
Normal file
332
backend/cmd/desktop/run_desktop_test.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
"nex/backend/internal/conversion"
|
||||||
|
"nex/backend/internal/database"
|
||||||
|
pkgLogger "nex/backend/pkg/logger"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeDesktopLock struct {
|
||||||
|
lockErr error
|
||||||
|
unlockCount atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeDesktopLock) Lock() error {
|
||||||
|
return l.lockErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeDesktopLock) Unlock() error {
|
||||||
|
l.unlockCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeDesktopLock) unlocked() bool {
|
||||||
|
return l.unlockCount.Load() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordingListener struct {
|
||||||
|
net.Listener
|
||||||
|
closeCount atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRecordingListener(t *testing.T) *recordingListener {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("创建测试 listener 失败: %v", err)
|
||||||
|
}
|
||||||
|
return &recordingListener{Listener: listener}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *recordingListener) Close() error {
|
||||||
|
l.closeCount.Add(1)
|
||||||
|
return l.Listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *recordingListener) closed() bool {
|
||||||
|
return l.closeCount.Load() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDesktopConfig(t *testing.T) *config.Config {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.Server.Port = 0
|
||||||
|
cfg.Database.Driver = "sqlite"
|
||||||
|
cfg.Database.Path = filepath.Join(tmpDir, "config.db")
|
||||||
|
cfg.Log.Path = filepath.Join(tmpDir, "log")
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func installDesktopTestHooks(t *testing.T, cfg *config.Config, mutate func(*desktopRuntimeHooks)) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
oldHooks := desktopHooks
|
||||||
|
oldServer := server
|
||||||
|
oldLogger := zapLogger
|
||||||
|
oldShutdownCtx := shutdownCtx
|
||||||
|
oldShutdownCancel := shutdownCancel
|
||||||
|
|
||||||
|
server = nil
|
||||||
|
zapLogger = nil
|
||||||
|
shutdownCtx = nil
|
||||||
|
shutdownCancel = nil
|
||||||
|
|
||||||
|
hooks := defaultDesktopRuntimeHooks()
|
||||||
|
if cfg != nil {
|
||||||
|
hooks.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
|
||||||
|
return cfg, config.ConfigMetadata{ConfigPath: filepath.Join(t.TempDir(), "config.yaml")}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hooks.upgradeLogger = func(_ *zap.Logger, _ pkgLogger.Config) (*zap.Logger, error) {
|
||||||
|
return zap.NewNop(), nil
|
||||||
|
}
|
||||||
|
hooks.setupStaticFiles = func(*gin.Engine) error { return nil }
|
||||||
|
hooks.startServer = func(*http.Server, net.Listener, chan<- error, *zap.Logger) {}
|
||||||
|
hooks.setupSystray = func(int, <-chan error) error { return nil }
|
||||||
|
|
||||||
|
if mutate != nil {
|
||||||
|
mutate(&hooks)
|
||||||
|
}
|
||||||
|
desktopHooks = hooks
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if server != nil {
|
||||||
|
_ = server.Close()
|
||||||
|
}
|
||||||
|
desktopHooks = oldHooks
|
||||||
|
server = oldServer
|
||||||
|
zapLogger = oldLogger
|
||||||
|
shutdownCtx = oldShutdownCtx
|
||||||
|
shutdownCancel = oldShutdownCancel
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireStartupPhase(t *testing.T, err error, want startupPhase) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("期望 %s 阶段启动错误,实际 nil", want)
|
||||||
|
}
|
||||||
|
var startupErr *startupError
|
||||||
|
if !errors.As(err, &startupErr) {
|
||||||
|
t.Fatalf("期望 startupError,实际: %T %v", err, err)
|
||||||
|
}
|
||||||
|
if startupErr.phase != want {
|
||||||
|
t.Fatalf("phase = %s, want %s", startupErr.phase, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopConfigFailureReturnsConfigPhase(t *testing.T) {
|
||||||
|
installDesktopTestHooks(t, nil, func(h *desktopRuntimeHooks) {
|
||||||
|
h.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
|
||||||
|
return nil, config.ConfigMetadata{}, errors.New("yaml 解析失败")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, phaseConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopSingletonFailurePrecedesPortListen(t *testing.T) {
|
||||||
|
cfg := testDesktopConfig(t)
|
||||||
|
lock := &fakeDesktopLock{lockErr: errors.New("已有实例运行")}
|
||||||
|
listenCalled := false
|
||||||
|
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||||
|
h.newLock = func(string) singletonLocker { return lock }
|
||||||
|
h.listen = func(int) (net.Listener, error) {
|
||||||
|
listenCalled = true
|
||||||
|
return nil, errors.New("不应监听端口")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, phaseSingleton)
|
||||||
|
if listenCalled {
|
||||||
|
t.Fatal("单实例锁失败时不应继续监听端口")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopPortFailureUnlocksSingleton(t *testing.T) {
|
||||||
|
cfg := testDesktopConfig(t)
|
||||||
|
lock := &fakeDesktopLock{}
|
||||||
|
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||||
|
h.newLock = func(string) singletonLocker { return lock }
|
||||||
|
h.listen = func(int) (net.Listener, error) { return nil, errors.New("bind failed") }
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, phasePort)
|
||||||
|
if !lock.unlocked() {
|
||||||
|
t.Fatal("端口监听失败时应释放单实例锁")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopLoggerFailureClosesListenerAndUnlocks(t *testing.T) {
|
||||||
|
cfg := testDesktopConfig(t)
|
||||||
|
lock := &fakeDesktopLock{}
|
||||||
|
listener := newRecordingListener(t)
|
||||||
|
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||||
|
h.newLock = func(string) singletonLocker { return lock }
|
||||||
|
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||||
|
h.upgradeLogger = func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error) {
|
||||||
|
return nil, errors.New("log permission denied")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, phaseLogger)
|
||||||
|
if !listener.closed() {
|
||||||
|
t.Fatal("日志初始化失败时应关闭 listener")
|
||||||
|
}
|
||||||
|
if !lock.unlocked() {
|
||||||
|
t.Fatal("日志初始化失败时应释放单实例锁")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopDatabaseFailureClassification(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want startupPhase
|
||||||
|
}{
|
||||||
|
{name: "database", err: errors.New("open failed"), want: phaseDatabase},
|
||||||
|
{name: "migration", err: fmt.Errorf("%w: %w", database.ErrMigration, errors.New("goose failed")), want: phaseMigration},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := testDesktopConfig(t)
|
||||||
|
lock := &fakeDesktopLock{}
|
||||||
|
listener := newRecordingListener(t)
|
||||||
|
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||||
|
h.newLock = func(string) singletonLocker { return lock }
|
||||||
|
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||||
|
h.initDB = func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error) { return nil, tt.err }
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, tt.want)
|
||||||
|
if !listener.closed() {
|
||||||
|
t.Fatal("数据库失败时应关闭 listener")
|
||||||
|
}
|
||||||
|
if !lock.unlocked() {
|
||||||
|
t.Fatal("数据库失败时应释放单实例锁")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopInternalStartupFailurePhasesAndDatabaseCleanup(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mutate func(*desktopRuntimeHooks)
|
||||||
|
want startupPhase
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "adapter",
|
||||||
|
mutate: func(h *desktopRuntimeHooks) {
|
||||||
|
h.registerAdapters = func(conversion.AdapterRegistry) error { return errors.New("adapter failed") }
|
||||||
|
},
|
||||||
|
want: phaseAdapter,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "static",
|
||||||
|
mutate: func(h *desktopRuntimeHooks) {
|
||||||
|
h.setupStaticFiles = func(*gin.Engine) error { return errors.New("missing frontend") }
|
||||||
|
},
|
||||||
|
want: phaseStaticResource,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server",
|
||||||
|
mutate: func(h *desktopRuntimeHooks) {
|
||||||
|
h.startServer = func(_ *http.Server, _ net.Listener, errCh chan<- error, _ *zap.Logger) {
|
||||||
|
errCh <- errors.New("serve failed")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: phaseServer,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tray",
|
||||||
|
mutate: func(h *desktopRuntimeHooks) {
|
||||||
|
h.setupSystray = func(int, <-chan error) error {
|
||||||
|
return newStartupError(phaseTray, "托盘初始化失败", errors.New("tray failed"))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
want: phaseTray,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := testDesktopConfig(t)
|
||||||
|
lock := &fakeDesktopLock{}
|
||||||
|
listener := newRecordingListener(t)
|
||||||
|
closeDBCalled := false
|
||||||
|
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||||
|
h.newLock = func(string) singletonLocker { return lock }
|
||||||
|
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||||
|
h.closeDB = func(db *gorm.DB) {
|
||||||
|
closeDBCalled = true
|
||||||
|
database.Close(db)
|
||||||
|
}
|
||||||
|
tt.mutate(h)
|
||||||
|
})
|
||||||
|
|
||||||
|
err := runDesktop(zap.NewNop())
|
||||||
|
requireStartupPhase(t, err, tt.want)
|
||||||
|
if !closeDBCalled {
|
||||||
|
t.Fatal("数据库初始化后的启动失败应关闭数据库")
|
||||||
|
}
|
||||||
|
if !listener.closed() {
|
||||||
|
t.Fatal("数据库初始化后的启动失败应关闭 listener")
|
||||||
|
}
|
||||||
|
if !lock.unlocked() {
|
||||||
|
t.Fatal("数据库初始化后的启动失败应释放单实例锁")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDesktopBrowserFailureRemainsNonFatal(t *testing.T) {
|
||||||
|
controller := newFakeTrayController()
|
||||||
|
notified := make(chan string, 1)
|
||||||
|
controller.run = func(onReady func(), _ func()) {
|
||||||
|
onReady()
|
||||||
|
<-controller.quitCh
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runSystray(19826, trayOptions{
|
||||||
|
controller: controller,
|
||||||
|
readyTimeout: time.Second,
|
||||||
|
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||||
|
openBrowser: func(string) error { return errors.New("no browser") },
|
||||||
|
notify: func(_, message string) {
|
||||||
|
notified <- message
|
||||||
|
controller.Quit()
|
||||||
|
},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("浏览器打开失败不应导致 runSystray 返回 fatal: %v", err)
|
||||||
|
}
|
||||||
|
if got := <-notified; got == "" {
|
||||||
|
t.Fatal("浏览器打开失败应提示用户手动访问")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
|||||||
if err := lock.Lock(); err != nil {
|
if err := lock.Lock(); err != nil {
|
||||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||||
}
|
}
|
||||||
defer lock.Unlock()
|
defer func() {
|
||||||
|
if err := lock.Unlock(); err != nil {
|
||||||
|
t.Fatalf("解锁失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||||
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
|||||||
if err := lock1.Lock(); err != nil {
|
if err := lock1.Lock(); err != nil {
|
||||||
t.Fatalf("首次加锁应成功: %v", err)
|
t.Fatalf("首次加锁应成功: %v", err)
|
||||||
}
|
}
|
||||||
defer lock1.Unlock()
|
defer func() {
|
||||||
|
if err := lock1.Unlock(); err != nil {
|
||||||
|
t.Fatalf("解锁失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
lock2 := NewSingletonLock(lockPath)
|
lock2 := NewSingletonLock(lockPath)
|
||||||
err := lock2.Lock()
|
err := lock2.Lock()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
lock2.Unlock()
|
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||||
|
t.Fatalf("解锁失败: %v", unlockErr)
|
||||||
|
}
|
||||||
t.Fatal("重复加锁应失败,但返回 nil")
|
t.Fatal("重复加锁应失败,但返回 nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
|||||||
if err := lock1.Lock(); err != nil {
|
if err := lock1.Lock(); err != nil {
|
||||||
t.Fatalf("首次加锁应成功: %v", err)
|
t.Fatalf("首次加锁应成功: %v", err)
|
||||||
}
|
}
|
||||||
lock1.Unlock()
|
if err := lock1.Unlock(); err != nil {
|
||||||
|
t.Fatalf("解锁失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
lock2 := NewSingletonLock(lockPath)
|
lock2 := NewSingletonLock(lockPath)
|
||||||
if err := lock2.Lock(); err != nil {
|
if err := lock2.Lock(); err != nil {
|
||||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||||
}
|
}
|
||||||
lock2.Unlock()
|
if err := lock2.Unlock(); err != nil {
|
||||||
|
t.Fatalf("解锁失败: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||||
lock.Unlock()
|
if err := lock.Unlock(); err != nil {
|
||||||
|
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
96
backend/cmd/desktop/startup_error.go
Normal file
96
backend/cmd/desktop/startup_error.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type startupPhase string
|
||||||
|
|
||||||
|
const (
|
||||||
|
phaseConfig startupPhase = "config"
|
||||||
|
phaseSingleton startupPhase = "singleton"
|
||||||
|
phasePort startupPhase = "port"
|
||||||
|
phaseLogger startupPhase = "logger"
|
||||||
|
phaseDatabase startupPhase = "database"
|
||||||
|
phaseMigration startupPhase = "migration"
|
||||||
|
phaseAdapter startupPhase = "adapter"
|
||||||
|
phaseStaticResource startupPhase = "static"
|
||||||
|
phaseServer startupPhase = "server"
|
||||||
|
phaseTray startupPhase = "tray"
|
||||||
|
)
|
||||||
|
|
||||||
|
type startupError struct {
|
||||||
|
phase startupPhase
|
||||||
|
message string
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStartupError(phase startupPhase, message string, cause error) *startupError {
|
||||||
|
return &startupError{
|
||||||
|
phase: phase,
|
||||||
|
message: redactSensitive(message),
|
||||||
|
cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *startupError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if e.cause == nil {
|
||||||
|
return fmt.Sprintf("%s: %s", e.phase, e.message)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s: %s", e.phase, e.message, redactSensitive(e.cause.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *startupError) Unwrap() error {
|
||||||
|
if e == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.cause
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *startupError) Phase() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(e.phase)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *startupError) UserMessage() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return redactSensitive(e.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sensitiveReplacers = []struct {
|
||||||
|
pattern *regexp.Regexp
|
||||||
|
replacement string
|
||||||
|
}{
|
||||||
|
{regexp.MustCompile(`(?i)(password\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||||
|
{regexp.MustCompile(`(?i)(api[_-]?key\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||||
|
{regexp.MustCompile(`(?i)(secret\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||||
|
{regexp.MustCompile(`([^\s:/]+):([^\s@]+)@tcp\(`), `${1}:<redacted>@tcp(`},
|
||||||
|
{regexp.MustCompile(`(://[^\s:/]+):([^\s@]+)@`), `${1}:<redacted>@`},
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactSensitive(s string) string {
|
||||||
|
for _, replacer := range sensitiveReplacers {
|
||||||
|
s = replacer.pattern.ReplaceAllString(s, replacer.replacement)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func startupTitle() string {
|
||||||
|
return appName + " 启动失败"
|
||||||
|
}
|
||||||
|
|
||||||
|
func startupServerErrorMessage() string {
|
||||||
|
return "后端服务启动失败\n\n请检查端口占用、网络权限或查看日志获取更多信息"
|
||||||
|
}
|
||||||
|
|
||||||
|
func startupInternalErrorMessage() string {
|
||||||
|
return "应用初始化失败\n\n请查看日志或重新安装应用"
|
||||||
|
}
|
||||||
40
backend/cmd/desktop/startup_error_test.go
Normal file
40
backend/cmd/desktop/startup_error_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStartupErrorContainsPhaseAndCause(t *testing.T) {
|
||||||
|
cause := errors.New("底层失败")
|
||||||
|
err := newStartupError(phaseDatabase, "数据库初始化失败", cause)
|
||||||
|
|
||||||
|
if err.Phase() != "database" {
|
||||||
|
t.Fatalf("phase = %q, want database", err.Phase())
|
||||||
|
}
|
||||||
|
if !errors.Is(err, cause) {
|
||||||
|
t.Fatal("startupError 应保留底层 cause")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "database") {
|
||||||
|
t.Fatalf("错误字符串应包含 phase,实际: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartupErrorRedactsSensitiveUserMessage(t *testing.T) {
|
||||||
|
message := "数据库初始化失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test"
|
||||||
|
err := newStartupError(phaseDatabase, message, errors.New("cause password=secret api_key=sk-test"))
|
||||||
|
userMessage := err.UserMessage()
|
||||||
|
|
||||||
|
for _, secret := range []string{"secret", "sk-test"} {
|
||||||
|
if strings.Contains(userMessage, secret) {
|
||||||
|
t.Fatalf("用户提示不应包含敏感信息 %q,实际: %s", secret, userMessage)
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), secret) {
|
||||||
|
t.Fatalf("日志错误字符串不应包含敏感信息 %q,实际: %s", secret, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(userMessage, "<redacted>") {
|
||||||
|
t.Fatalf("用户提示应包含脱敏占位符,实际: %s", userMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,72 +1,26 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"nex/embedfs"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetupStaticFiles(t *testing.T) {
|
func TestSetupStaticFiles(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
|
||||||
if err != nil {
|
|
||||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
getContentType := func(path string) string {
|
|
||||||
if strings.HasSuffix(path, ".js") {
|
|
||||||
return "application/javascript"
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(path, ".css") {
|
|
||||||
return "text/css"
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(path, ".svg") {
|
|
||||||
return "image/svg+xml"
|
|
||||||
}
|
|
||||||
return "application/octet-stream"
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||||
filepath := c.Param("filepath")
|
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
"icon.png": {Data: []byte("png")},
|
||||||
if err != nil {
|
"assets/test.js": {Data: []byte("console.log('test')")},
|
||||||
c.Status(404)
|
"assets/test.css": {Data: []byte("body {}")},
|
||||||
return
|
"assets/test.svg": {Data: []byte("<svg></svg>")},
|
||||||
}
|
"assets/test.woff": {Data: []byte("font")},
|
||||||
c.Data(200, getContentType(filepath), data)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
|
||||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
|
||||||
if err != nil {
|
|
||||||
c.Status(404)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Data(200, "image/svg+xml", data)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.NoRoute(func(c *gin.Context) {
|
|
||||||
path := c.Request.URL.Path
|
|
||||||
if strings.HasPrefix(path, "/api/") ||
|
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
|
||||||
strings.HasPrefix(path, "/health") {
|
|
||||||
c.JSON(404, gin.H{"error": "not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data, err := fs.ReadFile(distFS, "index.html")
|
|
||||||
if err != nil {
|
|
||||||
c.Status(500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Data(200, "text/html; charset=utf-8", data)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("API 404", func(t *testing.T) {
|
t.Run("API 404", func(t *testing.T) {
|
||||||
@@ -79,6 +33,32 @@ func TestSetupStaticFiles(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/openai/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "not found") {
|
||||||
|
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/anthropic/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "not found") {
|
||||||
|
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("SPA fallback", func(t *testing.T) {
|
t.Run("SPA fallback", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/providers", nil)
|
req := httptest.NewRequest("GET", "/providers", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -94,14 +74,13 @@ func TestSetupStaticFiles(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code == 200 {
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
expected := "application/javascript"
|
expected := "application/javascript"
|
||||||
if w.Header().Get("Content-Type") != expected {
|
if w.Header().Get("Content-Type") != expected {
|
||||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("MIME type for CSS", func(t *testing.T) {
|
t.Run("MIME type for CSS", func(t *testing.T) {
|
||||||
@@ -109,15 +88,144 @@ func TestSetupStaticFiles(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code == 200 {
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
expected := "text/css"
|
expected := "text/css"
|
||||||
if w.Header().Get("Content-Type") != expected {
|
if w.Header().Get("Content-Type") != expected {
|
||||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Log("静态文件服务测试通过")
|
t.Log("静态文件服务测试通过")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||||
|
"icon.png": {Data: []byte("png")},
|
||||||
|
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/icon.png", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if w.Header().Get("Content-Type") != "image/png" {
|
||||||
|
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
if w.Body.String() != "png" {
|
||||||
|
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
|
||||||
|
var gotProtocol string
|
||||||
|
var gotPath string
|
||||||
|
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
|
||||||
|
gotProtocol = c.Param("protocol")
|
||||||
|
gotPath = c.Param("path")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||||
|
}))
|
||||||
|
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
|
||||||
|
gotProtocol = c.Param("protocol")
|
||||||
|
gotPath = c.Param("path")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||||
|
}))
|
||||||
|
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||||
|
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||||
|
"assets/test.js": {Data: []byte("console.log('test')")},
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||||
|
gotProtocol = ""
|
||||||
|
gotPath = ""
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if gotProtocol != "openai" {
|
||||||
|
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/chat/completions" {
|
||||||
|
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
|
||||||
|
gotProtocol = ""
|
||||||
|
gotPath = ""
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if gotProtocol != "anthropic" {
|
||||||
|
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/messages" {
|
||||||
|
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Static assets are not hijacked", func(t *testing.T) {
|
||||||
|
gotProtocol = ""
|
||||||
|
gotPath = ""
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if gotProtocol != "" || gotPath != "" {
|
||||||
|
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||||
|
}
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("期望静态资源返回 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||||
|
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SPA path keeps fallback", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/providers", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
|
||||||
|
t.Errorf("期望返回 HTML,实际 %s", w.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/openai/unknown", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
|
||||||
|
}
|
||||||
|
if gotProtocol != "openai" || gotPath != "/unknown" {
|
||||||
|
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
231
backend/cmd/desktop/tray.go
Normal file
231
backend/cmd/desktop/tray.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"nex/embedfs"
|
||||||
|
|
||||||
|
"github.com/getlantern/systray"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTrayReadyTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
type trayMenuItem interface {
|
||||||
|
Disable()
|
||||||
|
Clicked() <-chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type trayController interface {
|
||||||
|
Run(onReady func(), onExit func())
|
||||||
|
Quit()
|
||||||
|
SetIcon(icon []byte)
|
||||||
|
SetTooltip(tooltip string)
|
||||||
|
AddMenuItem(title, tooltip string) trayMenuItem
|
||||||
|
AddSeparator()
|
||||||
|
}
|
||||||
|
|
||||||
|
type realTrayController struct{}
|
||||||
|
|
||||||
|
func (realTrayController) Run(onReady func(), onExit func()) {
|
||||||
|
systray.Run(onReady, onExit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realTrayController) Quit() {
|
||||||
|
systray.Quit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realTrayController) SetIcon(icon []byte) {
|
||||||
|
systray.SetIcon(icon)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realTrayController) SetTooltip(tooltip string) {
|
||||||
|
systray.SetTooltip(tooltip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
|
||||||
|
return realTrayMenuItem{item: systray.AddMenuItem(title, tooltip)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realTrayController) AddSeparator() {
|
||||||
|
systray.AddSeparator()
|
||||||
|
}
|
||||||
|
|
||||||
|
type realTrayMenuItem struct {
|
||||||
|
item *systray.MenuItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m realTrayMenuItem) Disable() {
|
||||||
|
m.item.Disable()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m realTrayMenuItem) Clicked() <-chan struct{} {
|
||||||
|
return m.item.ClickedCh
|
||||||
|
}
|
||||||
|
|
||||||
|
type trayOptions struct {
|
||||||
|
controller trayController
|
||||||
|
readyTimeout time.Duration
|
||||||
|
iconLoader func() ([]byte, error)
|
||||||
|
openBrowser func(string) error
|
||||||
|
notify func(string, string)
|
||||||
|
logger *zap.Logger
|
||||||
|
fatalErrCh <-chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupSystray(port int, fatalErrCh <-chan error) error {
|
||||||
|
return runSystray(port, trayOptions{
|
||||||
|
controller: realTrayController{},
|
||||||
|
readyTimeout: defaultTrayReadyTimeout,
|
||||||
|
iconLoader: loadTrayIcon,
|
||||||
|
openBrowser: openBrowser,
|
||||||
|
notify: showError,
|
||||||
|
logger: dialogLogger(),
|
||||||
|
fatalErrCh: fatalErrCh,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSystray(port int, opts trayOptions) error {
|
||||||
|
if opts.controller == nil {
|
||||||
|
opts.controller = realTrayController{}
|
||||||
|
}
|
||||||
|
if opts.readyTimeout <= 0 {
|
||||||
|
opts.readyTimeout = defaultTrayReadyTimeout
|
||||||
|
}
|
||||||
|
if opts.iconLoader == nil {
|
||||||
|
opts.iconLoader = loadTrayIcon
|
||||||
|
}
|
||||||
|
if opts.openBrowser == nil {
|
||||||
|
opts.openBrowser = openBrowser
|
||||||
|
}
|
||||||
|
if opts.notify == nil {
|
||||||
|
opts.notify = showError
|
||||||
|
}
|
||||||
|
if opts.logger == nil {
|
||||||
|
opts.logger = dialogLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
readyCh := make(chan struct{})
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
var readyOnce sync.Once
|
||||||
|
var errOnce sync.Once
|
||||||
|
|
||||||
|
signalReady := func() {
|
||||||
|
readyOnce.Do(func() { close(readyCh) })
|
||||||
|
}
|
||||||
|
signalError := func(err error) {
|
||||||
|
errOnce.Do(func() { errCh <- err })
|
||||||
|
}
|
||||||
|
|
||||||
|
go monitorTrayStartup(port, opts, readyCh, doneCh, signalError)
|
||||||
|
|
||||||
|
opts.controller.Run(func() {
|
||||||
|
handleTrayReady(port, opts, signalReady, signalError)
|
||||||
|
}, nil)
|
||||||
|
close(doneCh)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func monitorTrayStartup(port int, opts trayOptions, readyCh <-chan struct{}, doneCh <-chan struct{}, signalError func(error)) {
|
||||||
|
timer := time.NewTimer(opts.readyTimeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
ready := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-readyCh:
|
||||||
|
ready = true
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
openDesktopBrowser(port, opts)
|
||||||
|
readyCh = nil
|
||||||
|
case <-timer.C:
|
||||||
|
if !ready {
|
||||||
|
signalError(newStartupError(phaseTray, "托盘初始化超时", fmt.Errorf("托盘未在 %s 内 ready", opts.readyTimeout)))
|
||||||
|
opts.controller.Quit()
|
||||||
|
}
|
||||||
|
case err := <-opts.fatalErrCh:
|
||||||
|
if err != nil {
|
||||||
|
signalError(newStartupError(phaseServer, startupServerErrorMessage(), err))
|
||||||
|
opts.controller.Quit()
|
||||||
|
}
|
||||||
|
case <-doneCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleTrayReady(port int, opts trayOptions, signalReady func(), signalError func(error)) {
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
err := fmt.Errorf("托盘初始化 panic: %v", recovered)
|
||||||
|
signalError(newStartupError(phaseTray, "托盘菜单初始化失败", err))
|
||||||
|
opts.controller.Quit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
icon, err := opts.iconLoader()
|
||||||
|
if err != nil {
|
||||||
|
signalError(newStartupError(phaseTray, "托盘图标资源无法加载", err))
|
||||||
|
opts.controller.Quit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.controller.SetIcon(icon)
|
||||||
|
opts.controller.SetTooltip(appTooltip)
|
||||||
|
|
||||||
|
mOpen := opts.controller.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||||
|
opts.controller.AddSeparator()
|
||||||
|
mStatus := opts.controller.AddMenuItem("状态: 运行中", "")
|
||||||
|
mStatus.Disable()
|
||||||
|
mPort := opts.controller.AddMenuItem(desktopPortMenuTitle(port), "")
|
||||||
|
mPort.Disable()
|
||||||
|
opts.controller.AddSeparator()
|
||||||
|
mQuit := opts.controller.AddMenuItem("退出", "停止服务并退出")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-mOpen.Clicked():
|
||||||
|
if err := opts.openBrowser(desktopURL(port)); err != nil {
|
||||||
|
opts.logger.Warn("打开浏览器失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
case <-mQuit.Clicked():
|
||||||
|
doShutdown()
|
||||||
|
opts.controller.Quit()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
signalReady()
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDesktopBrowser(port int, opts trayOptions) {
|
||||||
|
url := desktopURL(port)
|
||||||
|
if err := opts.openBrowser(url); err != nil {
|
||||||
|
opts.logger.Warn("无法打开浏览器", zap.Error(err))
|
||||||
|
opts.notify(appName, fmt.Sprintf("无法自动打开浏览器,请手动访问 %s", url))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadTrayIcon() ([]byte, error) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return embedfs.Assets.ReadFile("assets/icon.ico")
|
||||||
|
}
|
||||||
|
return embedfs.Assets.ReadFile("assets/icon.png")
|
||||||
|
}
|
||||||
169
backend/cmd/desktop/tray_test.go
Normal file
169
backend/cmd/desktop/tray_test.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTrayController struct {
|
||||||
|
run func(onReady func(), onExit func())
|
||||||
|
|
||||||
|
quitCh chan struct{}
|
||||||
|
quitOnce sync.Once
|
||||||
|
|
||||||
|
icon []byte
|
||||||
|
tooltip string
|
||||||
|
menuItems []*fakeTrayMenuItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeTrayController() *fakeTrayController {
|
||||||
|
return &fakeTrayController{quitCh: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) Run(onReady func(), onExit func()) {
|
||||||
|
if c.run != nil {
|
||||||
|
c.run(onReady, onExit)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onReady()
|
||||||
|
<-c.quitCh
|
||||||
|
if onExit != nil {
|
||||||
|
onExit()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) Quit() {
|
||||||
|
c.quitOnce.Do(func() { close(c.quitCh) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) SetIcon(icon []byte) {
|
||||||
|
c.icon = append([]byte(nil), icon...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) SetTooltip(tooltip string) {
|
||||||
|
c.tooltip = tooltip
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
|
||||||
|
item := &fakeTrayMenuItem{clicked: make(chan struct{}), title: title, tooltip: tooltip}
|
||||||
|
c.menuItems = append(c.menuItems, item)
|
||||||
|
return item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeTrayController) AddSeparator() {}
|
||||||
|
|
||||||
|
type fakeTrayMenuItem struct {
|
||||||
|
clicked chan struct{}
|
||||||
|
title string
|
||||||
|
tooltip string
|
||||||
|
disabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fakeTrayMenuItem) Disable() {
|
||||||
|
m.disabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fakeTrayMenuItem) Clicked() <-chan struct{} {
|
||||||
|
return m.clicked
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunSystrayReadyOpensBrowser(t *testing.T) {
|
||||||
|
controller := newFakeTrayController()
|
||||||
|
opened := make(chan string, 1)
|
||||||
|
|
||||||
|
err := runSystray(19826, trayOptions{
|
||||||
|
controller: controller,
|
||||||
|
readyTimeout: time.Second,
|
||||||
|
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||||
|
openBrowser: func(url string) error {
|
||||||
|
opened <- url
|
||||||
|
controller.Quit()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
notify: func(string, string) {},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("托盘 ready 成功不应返回错误: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := <-opened; got != "http://localhost:19826" {
|
||||||
|
t.Fatalf("浏览器 URL = %s", got)
|
||||||
|
}
|
||||||
|
if string(controller.icon) != "icon" {
|
||||||
|
t.Fatalf("应设置托盘图标")
|
||||||
|
}
|
||||||
|
if controller.tooltip != appTooltip {
|
||||||
|
t.Fatalf("tooltip = %q, want %q", controller.tooltip, appTooltip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunSystrayReadyTimeoutReturnsTrayStartupError(t *testing.T) {
|
||||||
|
controller := newFakeTrayController()
|
||||||
|
controller.run = func(_ func(), _ func()) {
|
||||||
|
<-controller.quitCh
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runSystray(19826, trayOptions{
|
||||||
|
controller: controller,
|
||||||
|
readyTimeout: 10 * time.Millisecond,
|
||||||
|
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||||
|
openBrowser: func(string) error { return nil },
|
||||||
|
notify: func(string, string) {},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("托盘 ready timeout 应返回错误")
|
||||||
|
}
|
||||||
|
var startupErr *startupError
|
||||||
|
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
|
||||||
|
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunSystrayIconLoadFailureReturnsTrayStartupError(t *testing.T) {
|
||||||
|
controller := newFakeTrayController()
|
||||||
|
|
||||||
|
err := runSystray(19826, trayOptions{
|
||||||
|
controller: controller,
|
||||||
|
readyTimeout: time.Second,
|
||||||
|
iconLoader: func() ([]byte, error) { return nil, errors.New("missing icon") },
|
||||||
|
openBrowser: func(string) error { return nil },
|
||||||
|
notify: func(string, string) {},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("托盘图标加载失败应返回错误")
|
||||||
|
}
|
||||||
|
var startupErr *startupError
|
||||||
|
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
|
||||||
|
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunSystrayBrowserOpenFailureIsNonFatal(t *testing.T) {
|
||||||
|
controller := newFakeTrayController()
|
||||||
|
notified := make(chan string, 1)
|
||||||
|
|
||||||
|
err := runSystray(19826, trayOptions{
|
||||||
|
controller: controller,
|
||||||
|
readyTimeout: time.Second,
|
||||||
|
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||||
|
openBrowser: func(string) error { return errors.New("no browser") },
|
||||||
|
notify: func(_, message string) {
|
||||||
|
notified <- message
|
||||||
|
controller.Quit()
|
||||||
|
},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("浏览器打开失败不应成为 fatal: %v", err)
|
||||||
|
}
|
||||||
|
if got := <-notified; got == "" {
|
||||||
|
t.Fatal("浏览器打开失败应提示用户")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,46 +3,38 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pressly/goose/v3"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
"nex/backend/internal/config"
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/anthropic"
|
"nex/backend/internal/conversion/anthropic"
|
||||||
"nex/backend/internal/conversion/openai"
|
"nex/backend/internal/conversion/openai"
|
||||||
|
"nex/backend/internal/database"
|
||||||
"nex/backend/internal/handler"
|
"nex/backend/internal/handler"
|
||||||
"nex/backend/internal/handler/middleware"
|
"nex/backend/internal/handler/middleware"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
"nex/backend/internal/repository"
|
"nex/backend/internal/repository"
|
||||||
"nex/backend/internal/service"
|
"nex/backend/internal/service"
|
||||||
|
"nex/backend/pkg/buildinfo"
|
||||||
pkgLogger "nex/backend/pkg/logger"
|
pkgLogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// 1. 加载配置(已包含 CLI 参数解析、环境变量绑定、配置文件读取和验证)
|
minimalLogger := pkgLogger.NewMinimal()
|
||||||
cfg, err := config.LoadConfig()
|
|
||||||
|
cfg, cfgMeta, err := config.LoadServerConfigWithMetadata()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("加载配置失败: %v", err)
|
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 打印配置摘要
|
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||||
cfg.PrintSummary()
|
|
||||||
|
|
||||||
// 3. 初始化日志
|
|
||||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
|
||||||
Level: cfg.Log.Level,
|
Level: cfg.Log.Level,
|
||||||
Path: cfg.Log.Path,
|
Path: cfg.Log.Path,
|
||||||
MaxSize: cfg.Log.MaxSize,
|
MaxSize: cfg.Log.MaxSize,
|
||||||
@@ -51,60 +43,59 @@ func main() {
|
|||||||
Compress: cfg.Log.Compress,
|
Compress: cfg.Log.Compress,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("初始化日志失败: %v", err)
|
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer zapLogger.Sync()
|
defer func() {
|
||||||
|
if err := zapLogger.Sync(); err != nil {
|
||||||
|
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 3. 初始化数据库
|
cfg.PrintSummary(zapLogger)
|
||||||
db, err := initDatabase(cfg)
|
|
||||||
|
db, err := database.Init(&cfg.Database, zapLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer closeDB(db)
|
defer database.Close(db)
|
||||||
|
|
||||||
// 4. 初始化 repository 层
|
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
statsRepo := repository.NewStatsRepository(db)
|
statsRepo := repository.NewStatsRepository(db)
|
||||||
|
|
||||||
// 5. 初始化缓存
|
|
||||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||||
if err := routingCache.Preload(); err != nil {
|
if err := routingCache.Preload(); err != nil {
|
||||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. 初始化统计缓冲
|
|
||||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||||
service.WithFlushInterval(5*time.Second),
|
service.WithFlushInterval(5*time.Second),
|
||||||
service.WithFlushThreshold(100))
|
service.WithFlushThreshold(100))
|
||||||
statsBuffer.Start()
|
statsBuffer.Start()
|
||||||
|
|
||||||
// 7. 初始化 service 层
|
|
||||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||||
routingService := service.NewRoutingService(routingCache)
|
routingService := service.NewRoutingService(routingCache)
|
||||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||||
|
|
||||||
// 8. 创建 ConversionEngine
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||||
|
|
||||||
// 9. 初始化 provider client
|
providerClient := provider.NewClient(zapLogger)
|
||||||
providerClient := provider.NewClient()
|
|
||||||
|
|
||||||
// 10. 初始化 handler 层
|
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
|
||||||
providerHandler := handler.NewProviderHandler(providerService)
|
providerHandler := handler.NewProviderHandler(providerService)
|
||||||
modelHandler := handler.NewModelHandler(modelService)
|
modelHandler := handler.NewModelHandler(modelService)
|
||||||
statsHandler := handler.NewStatsHandler(statsService)
|
statsHandler := handler.NewStatsHandler(statsService)
|
||||||
|
versionHandler := handler.NewVersionHandler()
|
||||||
|
settingsHandler := handler.NewSettingsHandler(cfg, "server", false, cfgMeta.ConfigPath)
|
||||||
|
|
||||||
// 11. 创建 Gin 引擎
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
|
|
||||||
@@ -113,20 +104,23 @@ func main() {
|
|||||||
r.Use(middleware.Logging(zapLogger))
|
r.Use(middleware.Logging(zapLogger))
|
||||||
r.Use(middleware.CORS())
|
r.Use(middleware.CORS())
|
||||||
|
|
||||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
|
||||||
|
|
||||||
// 12. 启动服务器
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: formatAddr(cfg.Server.Port),
|
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||||
Handler: r,
|
Handler: r,
|
||||||
ReadTimeout: cfg.Server.ReadTimeout,
|
ReadTimeout: cfg.Server.ReadTimeout,
|
||||||
WriteTimeout: cfg.Server.WriteTimeout,
|
WriteTimeout: cfg.Server.WriteTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
zapLogger.Info("AI Gateway 启动",
|
||||||
|
zap.String("addr", srv.Addr),
|
||||||
|
zap.String("version", buildinfo.Version()),
|
||||||
|
zap.String("commit", buildinfo.Commit()),
|
||||||
|
zap.String("build_time", buildinfo.BuildTime()))
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -140,7 +134,7 @@ func main() {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
statsBuffer.Stop()
|
statsBuffer.Stop()
|
||||||
@@ -148,83 +142,10 @@ func main() {
|
|||||||
zapLogger.Info("服务器已关闭")
|
zapLogger.Info("服务器已关闭")
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
|
||||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
|
||||||
Logger: logger.Default.LogMode(logger.Info),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := runMigrations(db); err != nil {
|
|
||||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
|
||||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlDB, err := db.DB()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
|
||||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
|
||||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
|
||||||
|
|
||||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
|
||||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runMigrations(db *gorm.DB) error {
|
|
||||||
sqlDB, err := db.DB()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
migrationsDir := getMigrationsDir()
|
|
||||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
goose.SetDialect("sqlite3")
|
|
||||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMigrationsDir() string {
|
|
||||||
_, filename, _, ok := runtime.Caller(0)
|
|
||||||
if ok {
|
|
||||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
|
||||||
if abs, err := filepath.Abs(dir); err == nil {
|
|
||||||
return abs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "./migrations"
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeDB(db *gorm.DB) {
|
|
||||||
sqlDB, err := db.DB()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sqlDB.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatAddr(port int) string {
|
|
||||||
return fmt.Sprintf(":%d", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
|
||||||
// 统一代理入口: /{protocol}/{path}
|
|
||||||
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
||||||
|
r.GET("/api/version", versionHandler.GetVersion)
|
||||||
|
|
||||||
// 供应商管理 API
|
|
||||||
providers := r.Group("/api/providers")
|
providers := r.Group("/api/providers")
|
||||||
{
|
{
|
||||||
providers.GET("", providerHandler.ListProviders)
|
providers.GET("", providerHandler.ListProviders)
|
||||||
@@ -234,7 +155,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
|||||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 模型管理 API
|
|
||||||
models := r.Group("/api/models")
|
models := r.Group("/api/models")
|
||||||
{
|
{
|
||||||
models.GET("", modelHandler.ListModels)
|
models.GET("", modelHandler.ListModels)
|
||||||
@@ -244,14 +164,18 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
|||||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计查询 API
|
|
||||||
stats := r.Group("/api/stats")
|
stats := r.Group("/api/stats")
|
||||||
{
|
{
|
||||||
stats.GET("", statsHandler.GetStats)
|
stats.GET("", statsHandler.GetStats)
|
||||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 健康检查
|
settings := r.Group("/api/settings")
|
||||||
|
{
|
||||||
|
settings.GET("/startup", settingsHandler.GetStartupSettings)
|
||||||
|
settings.PUT("/startup", settingsHandler.SaveStartupSettings)
|
||||||
|
}
|
||||||
|
|
||||||
r.GET("/health", func(c *gin.Context) {
|
r.GET("/health", func(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
|||||||
37
backend/cmd/server/routes_test.go
Normal file
37
backend/cmd/server/routes_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/handler"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetupRoutes_Version(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler(), handler.NewSettingsHandler(nil, "server", false, ""))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||||
|
t.Fatalf("解析响应失败: %v", err)
|
||||||
|
}
|
||||||
|
for _, key := range []string{"version", "commit", "build_time"} {
|
||||||
|
if result[key] == "" {
|
||||||
|
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,6 +24,7 @@ require (
|
|||||||
go.uber.org/zap v1.27.1
|
go.uber.org/zap v1.27.1
|
||||||
gopkg.in/lumberjack.v2 v2.0.0
|
gopkg.in/lumberjack.v2 v2.0.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
gorm.io/driver/mysql v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/gorm v1.31.1
|
gorm.io/gorm v1.31.1
|
||||||
nex/embedfs v0.0.0-00010101000000-000000000000
|
nex/embedfs v0.0.0-00010101000000-000000000000
|
||||||
@@ -32,6 +33,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
||||||
4d63.com/gochecknoglobals v0.2.2 // indirect
|
4d63.com/gochecknoglobals v0.2.2 // indirect
|
||||||
|
filippo.io/edwards25519 v1.2.0 // indirect
|
||||||
github.com/4meepo/tagalign v1.4.2 // indirect
|
github.com/4meepo/tagalign v1.4.2 // indirect
|
||||||
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
||||||
github.com/Antonboom/errname v1.0.0 // indirect
|
github.com/Antonboom/errname v1.0.0 // indirect
|
||||||
@@ -90,6 +92,7 @@ require (
|
|||||||
github.com/go-critic/go-critic v0.12.0 // indirect
|
github.com/go-critic/go-critic v0.12.0 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||||
github.com/go-stack/stack v1.8.0 // indirect
|
github.com/go-stack/stack v1.8.0 // indirect
|
||||||
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
||||||
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
|
|||||||
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
|
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
|
||||||
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
|
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
|
||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
|
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
|
||||||
|
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
|
||||||
github.com/4meepo/tagalign v1.4.2 h1:0hcLHPGMjDyM1gHG58cS73aQF8J4TdVR96TZViorO9E=
|
github.com/4meepo/tagalign v1.4.2 h1:0hcLHPGMjDyM1gHG58cS73aQF8J4TdVR96TZViorO9E=
|
||||||
github.com/4meepo/tagalign v1.4.2/go.mod h1:+p4aMyFM+ra7nb41CnFG6aSDXqRxU/w1VQqScKqDARI=
|
github.com/4meepo/tagalign v1.4.2/go.mod h1:+p4aMyFM+ra7nb41CnFG6aSDXqRxU/w1VQqScKqDARI=
|
||||||
github.com/Abirdcfly/dupword v0.1.3 h1:9Pa1NuAsZvpFPi9Pqkd93I7LIYRURj+A//dFd5tgBeE=
|
github.com/Abirdcfly/dupword v0.1.3 h1:9Pa1NuAsZvpFPi9Pqkd93I7LIYRURj+A//dFd5tgBeE=
|
||||||
@@ -206,6 +208,8 @@ github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK
|
|||||||
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
||||||
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
|
||||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
|
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
|
||||||
@@ -1052,6 +1056,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
|
||||||
|
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
|
||||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
"go.uber.org/zap"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
appErrors "nex/backend/pkg/errors"
|
appErrors "nex/backend/pkg/errors"
|
||||||
@@ -32,7 +33,13 @@ type ServerConfig struct {
|
|||||||
|
|
||||||
// DatabaseConfig 数据库配置
|
// DatabaseConfig 数据库配置
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
Path string `yaml:"path" mapstructure:"path" validate:"required"`
|
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
|
||||||
|
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
|
||||||
|
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
|
||||||
|
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,omitempty,min=1,max=65535"`
|
||||||
|
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
|
||||||
|
Password string `yaml:"password" mapstructure:"password"`
|
||||||
|
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
|
||||||
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
||||||
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
||||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
||||||
@@ -51,7 +58,10 @@ type LogConfig struct {
|
|||||||
// DefaultConfig returns default config values
|
// DefaultConfig returns default config values
|
||||||
func DefaultConfig() *Config {
|
func DefaultConfig() *Config {
|
||||||
// Use home dir for default paths
|
// Use home dir for default paths
|
||||||
homeDir, _ := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
homeDir = "."
|
||||||
|
}
|
||||||
nexDir := filepath.Join(homeDir, ".nex")
|
nexDir := filepath.Join(homeDir, ".nex")
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
@@ -61,7 +71,13 @@ func DefaultConfig() *Config {
|
|||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
},
|
},
|
||||||
Database: DatabaseConfig{
|
Database: DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
Path: filepath.Join(nexDir, "config.db"),
|
Path: filepath.Join(nexDir, "config.db"),
|
||||||
|
Host: "",
|
||||||
|
Port: 3306,
|
||||||
|
User: "",
|
||||||
|
Password: "",
|
||||||
|
DBName: "nex",
|
||||||
MaxIdleConns: 10,
|
MaxIdleConns: 10,
|
||||||
MaxOpenConns: 100,
|
MaxOpenConns: 100,
|
||||||
ConnMaxLifetime: 1 * time.Hour,
|
ConnMaxLifetime: 1 * time.Hour,
|
||||||
@@ -84,7 +100,7 @@ func GetConfigDir() (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
configDir := filepath.Join(homeDir, ".nex")
|
configDir := filepath.Join(homeDir, ".nex")
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return configDir, nil
|
return configDir, nil
|
||||||
@@ -110,14 +126,23 @@ func GetConfigPath() (string, error) {
|
|||||||
|
|
||||||
// setupDefaults 设置默认配置值
|
// setupDefaults 设置默认配置值
|
||||||
func setupDefaults(v *viper.Viper) {
|
func setupDefaults(v *viper.Viper) {
|
||||||
homeDir, _ := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
homeDir = "."
|
||||||
|
}
|
||||||
nexDir := filepath.Join(homeDir, ".nex")
|
nexDir := filepath.Join(homeDir, ".nex")
|
||||||
|
|
||||||
v.SetDefault("server.port", 9826)
|
v.SetDefault("server.port", 9826)
|
||||||
v.SetDefault("server.read_timeout", "30s")
|
v.SetDefault("server.read_timeout", "30s")
|
||||||
v.SetDefault("server.write_timeout", "30s")
|
v.SetDefault("server.write_timeout", "30s")
|
||||||
|
|
||||||
|
v.SetDefault("database.driver", "sqlite")
|
||||||
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
|
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
|
||||||
|
v.SetDefault("database.host", "")
|
||||||
|
v.SetDefault("database.port", 3306)
|
||||||
|
v.SetDefault("database.user", "")
|
||||||
|
v.SetDefault("database.password", "")
|
||||||
|
v.SetDefault("database.dbname", "nex")
|
||||||
v.SetDefault("database.max_idle_conns", 10)
|
v.SetDefault("database.max_idle_conns", 10)
|
||||||
v.SetDefault("database.max_open_conns", 100)
|
v.SetDefault("database.max_open_conns", 100)
|
||||||
v.SetDefault("database.conn_max_lifetime", "1h")
|
v.SetDefault("database.conn_max_lifetime", "1h")
|
||||||
@@ -138,7 +163,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
|||||||
flagSet.Duration("server-read-timeout", 0, "读超时")
|
flagSet.Duration("server-read-timeout", 0, "读超时")
|
||||||
flagSet.Duration("server-write-timeout", 0, "写超时")
|
flagSet.Duration("server-write-timeout", 0, "写超时")
|
||||||
|
|
||||||
|
flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
|
||||||
flagSet.String("database-path", "", "数据库文件路径")
|
flagSet.String("database-path", "", "数据库文件路径")
|
||||||
|
flagSet.String("database-host", "", "MySQL 主机地址")
|
||||||
|
flagSet.Int("database-port", 0, "MySQL 端口")
|
||||||
|
flagSet.String("database-user", "", "MySQL 用户名")
|
||||||
|
flagSet.String("database-password", "", "MySQL 密码")
|
||||||
|
flagSet.String("database-dbname", "", "MySQL 数据库名")
|
||||||
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
||||||
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
||||||
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
||||||
@@ -152,21 +183,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
|||||||
|
|
||||||
// 绑定所有 flag 到 viper
|
// 绑定所有 flag 到 viper
|
||||||
// 注意:必须在设置默认值之后绑定
|
// 注意:必须在设置默认值之后绑定
|
||||||
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
|
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||||
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||||
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||||
|
|
||||||
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
|
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||||
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||||
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||||
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||||
|
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||||
|
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||||
|
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||||
|
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||||
|
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||||
|
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||||
|
|
||||||
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
|
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||||
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
|
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||||
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
|
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||||
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
|
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||||
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
|
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||||
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
|
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||||
|
if err := v.BindPFlag(key, flag); err != nil {
|
||||||
|
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupEnv 绑定环境变量
|
// setupEnv 绑定环境变量
|
||||||
@@ -181,73 +224,156 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
|
|||||||
v.SetConfigFile(configPath)
|
v.SetConfigFile(configPath)
|
||||||
v.SetConfigType("yaml")
|
v.SetConfigType("yaml")
|
||||||
|
|
||||||
// 尝试读取配置文件,如果不存在则忽略
|
|
||||||
if err := v.ReadInConfig(); err != nil {
|
if err := v.ReadInConfig(); err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
}
|
}
|
||||||
// 配置文件不存在,创建默认配置文件
|
|
||||||
if err := v.SafeWriteConfig(); err != nil {
|
|
||||||
// 忽略写入错误(可能目录已存在等)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig loads config from YAML file, creates default if not exists
|
type ConfigMetadata struct {
|
||||||
func LoadConfig() (*Config, error) {
|
ConfigPath string
|
||||||
configPath, err := GetConfigPath()
|
|
||||||
if err != nil {
|
|
||||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
|
||||||
}
|
|
||||||
return LoadConfigFromPath(configPath)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfigFromPath 从指定路径加载配置
|
type loadOptions struct {
|
||||||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
configPathOverride string
|
||||||
// 1. 创建 Viper 实例
|
useCLI bool
|
||||||
|
useEnv bool
|
||||||
|
useConfigFlag bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveConfigPath 根据 loadOptions 解析 CLI 参数并返回最终配置文件路径
|
||||||
|
func resolveConfigPath(v *viper.Viper, opts loadOptions) (string, error) {
|
||||||
|
configPath := opts.configPathOverride
|
||||||
|
|
||||||
|
if !opts.useCLI && !opts.useConfigFlag {
|
||||||
|
return configPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||||||
|
if opts.useConfigFlag {
|
||||||
|
flagSet.String("config", opts.configPathOverride, "配置文件路径")
|
||||||
|
}
|
||||||
|
if opts.useCLI {
|
||||||
|
setupFlags(v, flagSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||||
|
return "", appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.useConfigFlag {
|
||||||
|
if f, err := flagSet.GetString("config"); err == nil && f != "" {
|
||||||
|
configPath = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadConfig(opts loadOptions) (*Config, error) {
|
||||||
|
cfg, _, err := loadConfigWithMetadata(opts)
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadConfigWithMetadata(opts loadOptions) (*Config, ConfigMetadata, error) {
|
||||||
v := viper.New()
|
v := viper.New()
|
||||||
|
|
||||||
// 2. 定义 CLI 参数
|
|
||||||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
|
||||||
flagSet.String("config", configPath, "配置文件路径")
|
|
||||||
setupFlags(v, flagSet)
|
|
||||||
|
|
||||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
|
||||||
flagSet.Parse(os.Args[1:])
|
|
||||||
|
|
||||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
|
||||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
|
||||||
configPath = configPathFlag
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 设置默认值
|
|
||||||
setupDefaults(v)
|
setupDefaults(v)
|
||||||
|
|
||||||
// 6. 绑定环境变量
|
configPath, err := resolveConfigPath(v, opts)
|
||||||
setupEnv(v)
|
if err != nil {
|
||||||
|
return nil, ConfigMetadata{}, err
|
||||||
// 7. 读取配置文件
|
}
|
||||||
if err := setupConfigFile(v, configPath); err != nil {
|
|
||||||
return nil, err
|
if opts.useEnv {
|
||||||
|
setupEnv(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setupConfigFile(v, configPath); err != nil {
|
||||||
|
return nil, ConfigMetadata{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. 反序列化到结构体
|
|
||||||
cfg := &Config{}
|
cfg := &Config{}
|
||||||
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||||
mapstructure.StringToTimeDurationHookFunc(),
|
mapstructure.StringToTimeDurationHookFunc(),
|
||||||
mapstructure.StringToSliceHookFunc(","),
|
mapstructure.StringToSliceHookFunc(","),
|
||||||
))); err != nil {
|
))); err != nil {
|
||||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. 验证配置
|
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, ConfigMetadata{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, ConfigMetadata{ConfigPath: configPath}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadServerConfig 为 server 入口加载配置,支持 CLI 参数、环境变量和 --config
|
||||||
|
func LoadServerConfig() (*Config, error) {
|
||||||
|
cfg, _, err := LoadServerConfigWithMetadata()
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadServerConfigWithMetadata() (*Config, ConfigMetadata, error) {
|
||||||
|
configPath, err := GetConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
|
}
|
||||||
|
return loadConfigWithMetadata(loadOptions{
|
||||||
|
configPathOverride: configPath,
|
||||||
|
useCLI: true,
|
||||||
|
useEnv: true,
|
||||||
|
useConfigFlag: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadDesktopConfig 为 desktop 入口加载配置,固定使用默认配置文件,不支持 CLI、环境变量和 --config
|
||||||
|
func LoadDesktopConfig() (*Config, error) {
|
||||||
|
cfg, _, err := LoadDesktopConfigWithMetadata()
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadDesktopConfigWithMetadata() (*Config, ConfigMetadata, error) {
|
||||||
|
configPath, err := GetConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, ConfigMetadata{}, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
|
}
|
||||||
|
return loadConfigWithMetadata(loadOptions{
|
||||||
|
configPathOverride: configPath,
|
||||||
|
useCLI: false,
|
||||||
|
useEnv: false,
|
||||||
|
useConfigFlag: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads config from YAML file.
|
||||||
|
// 向后兼容,等同于 LoadServerConfig。
|
||||||
|
func LoadConfig() (*Config, error) {
|
||||||
|
return LoadServerConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfigFromPath 从指定路径加载配置。
|
||||||
|
// 保留向后兼容,沿用 server 语义(支持 CLI、env 和 --config 覆盖)。
|
||||||
|
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||||
|
return loadConfig(loadOptions{
|
||||||
|
configPathOverride: configPath,
|
||||||
|
useCLI: true,
|
||||||
|
useEnv: true,
|
||||||
|
useConfigFlag: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadDesktopConfigAtPath 从指定路径以 desktop 语义加载配置(仅配置文件和默认值),用于测试场景。
|
||||||
|
func LoadDesktopConfigAtPath(configPath string) (*Config, error) {
|
||||||
|
return loadConfig(loadOptions{
|
||||||
|
configPathOverride: configPath,
|
||||||
|
useCLI: false,
|
||||||
|
useEnv: false,
|
||||||
|
useConfigFlag: false,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveConfig saves config to YAML file
|
// SaveConfig saves config to YAML file
|
||||||
@@ -256,19 +382,21 @@ func SaveConfig(cfg *Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
}
|
}
|
||||||
|
return SaveConfigToPath(cfg, configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveConfigToPath(cfg *Config, configPath string) error {
|
||||||
data, err := yaml.Marshal(cfg)
|
data, err := yaml.Marshal(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure directory exists
|
|
||||||
dir := filepath.Dir(configPath)
|
dir := filepath.Dir(configPath)
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.WriteFile(configPath, data, 0644)
|
return os.WriteFile(configPath, data, 0o600)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the config
|
// Validate validates the config
|
||||||
@@ -281,16 +409,24 @@ func (c *Config) Validate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PrintSummary 打印配置摘要
|
// PrintSummary 打印配置摘要
|
||||||
func (c *Config) PrintSummary() {
|
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||||||
fmt.Println("\nAI Gateway 启动配置")
|
logger.Info("AI Gateway 启动配置",
|
||||||
fmt.Println("==================")
|
zap.Int("server_port", c.Server.Port),
|
||||||
fmt.Printf("服务器端口: %d\n", c.Server.Port)
|
zap.String("database_driver", c.Database.Driver),
|
||||||
fmt.Printf("数据库路径: %s\n", c.Database.Path)
|
zap.String("log_level", c.Log.Level),
|
||||||
fmt.Printf("日志级别: %s\n", c.Log.Level)
|
)
|
||||||
fmt.Println("\n配置来源:")
|
|
||||||
configPath, _ := GetConfigPath()
|
if c.Database.Driver == "mysql" {
|
||||||
fmt.Printf(" 配置文件: %s\n", configPath)
|
logger.Info("数据库配置",
|
||||||
fmt.Println(" 环境变量: 待统计")
|
zap.String("driver", "mysql"),
|
||||||
fmt.Println(" CLI 参数: 待统计")
|
zap.String("host", c.Database.Host),
|
||||||
fmt.Println()
|
zap.Int("port", c.Database.Port),
|
||||||
|
zap.String("database", c.Database.DBName),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
logger.Info("数据库配置",
|
||||||
|
zap.String("driver", "sqlite"),
|
||||||
|
zap.String("path", c.Database.Path),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
151
backend/internal/config/config_metadata_test.go
Normal file
151
backend/internal/config/config_metadata_test.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadDesktopConfigAtPath_WithMetadata(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.Server.Port = 8888
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(configPath, data, 0o600))
|
||||||
|
|
||||||
|
loaded, meta, err := loadConfigWithMetadata(loadOptions{
|
||||||
|
configPathOverride: configPath,
|
||||||
|
useCLI: false,
|
||||||
|
useEnv: false,
|
||||||
|
useConfigFlag: false,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 8888, loaded.Server.Port)
|
||||||
|
assert.Equal(t, configPath, meta.ConfigPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveConfigToPath(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "sub", "config.yaml")
|
||||||
|
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.Server.Port = 7777
|
||||||
|
|
||||||
|
err := SaveConfigToPath(cfg, configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), "7777")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveConfigToPath_InvalidDir(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
err := SaveConfigToPath(cfg, "/dev/null/impossible/config.yaml")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDurationConversion(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
dto := configToDTO(cfg)
|
||||||
|
|
||||||
|
parsed, err := time.ParseDuration(dto.Server.ReadTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, cfg.Server.ReadTimeout, parsed)
|
||||||
|
|
||||||
|
parsed, err = time.ParseDuration(dto.Database.ConnMaxLifetime)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, cfg.Database.ConnMaxLifetime, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveConfigToPath_DurationFormat(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.Server.ReadTimeout = 30 * time.Second
|
||||||
|
cfg.Server.WriteTimeout = 1 * time.Minute
|
||||||
|
cfg.Database.ConnMaxLifetime = 1 * time.Hour
|
||||||
|
|
||||||
|
err := SaveConfigToPath(cfg, configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
content := string(data)
|
||||||
|
assert.Contains(t, content, "conn_max_lifetime: 1h0m0s")
|
||||||
|
assert.Contains(t, content, "read_timeout: 30s")
|
||||||
|
assert.Contains(t, content, "write_timeout: 1m0s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndReload_DurationRoundTrip(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
yamlContent := `
|
||||||
|
server:
|
||||||
|
port: 9826
|
||||||
|
read_timeout: 30s
|
||||||
|
write_timeout: 1m
|
||||||
|
database:
|
||||||
|
driver: sqlite
|
||||||
|
path: ` + filepath.Join(dir, "test.db") + `
|
||||||
|
max_idle_conns: 10
|
||||||
|
max_open_conns: 100
|
||||||
|
conn_max_lifetime: 30m
|
||||||
|
log:
|
||||||
|
level: info
|
||||||
|
path: ` + filepath.Join(dir, "log") + `
|
||||||
|
max_size: 100
|
||||||
|
max_backups: 10
|
||||||
|
max_age: 30
|
||||||
|
compress: true
|
||||||
|
`
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(yamlContent), 0o600))
|
||||||
|
|
||||||
|
cfg, err := LoadDesktopConfigAtPath(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 30*time.Minute, cfg.Database.ConnMaxLifetime)
|
||||||
|
|
||||||
|
err = SaveConfigToPath(cfg, configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), "conn_max_lifetime: 30m0s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func configToDTO(c *Config) struct {
|
||||||
|
Server struct {
|
||||||
|
Port int `json:"port"`
|
||||||
|
ReadTimeout string `json:"read_timeout"`
|
||||||
|
WriteTimeout string `json:"write_timeout"`
|
||||||
|
}
|
||||||
|
Database struct {
|
||||||
|
ConnMaxLifetime string `json:"conn_max_lifetime"`
|
||||||
|
}
|
||||||
|
} {
|
||||||
|
var result struct {
|
||||||
|
Server struct {
|
||||||
|
Port int `json:"port"`
|
||||||
|
ReadTimeout string `json:"read_timeout"`
|
||||||
|
WriteTimeout string `json:"write_timeout"`
|
||||||
|
}
|
||||||
|
Database struct {
|
||||||
|
ConnMaxLifetime string `json:"conn_max_lifetime"`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.Server.Port = c.Server.Port
|
||||||
|
result.Server.ReadTimeout = c.Server.ReadTimeout.String()
|
||||||
|
result.Server.WriteTimeout = c.Server.WriteTimeout.String()
|
||||||
|
result.Database.ConnMaxLifetime = c.Database.ConnMaxLifetime.String()
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +20,12 @@ func TestDefaultConfig(t *testing.T) {
|
|||||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||||
|
|
||||||
|
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||||
|
assert.Equal(t, "", cfg.Database.Host)
|
||||||
|
assert.Equal(t, 3306, cfg.Database.Port)
|
||||||
|
assert.Equal(t, "", cfg.Database.User)
|
||||||
|
assert.Equal(t, "", cfg.Database.Password)
|
||||||
|
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||||
@@ -86,11 +93,76 @@ func TestConfig_Validate(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "数据库路径为空无效",
|
name: "SQLite模式路径为空无效",
|
||||||
modify: func(c *Config) { c.Database.Path = "" },
|
modify: func(c *Config) { c.Database.Path = "" },
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errMsg: "配置验证失败",
|
errMsg: "配置验证失败",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "driver值不合法",
|
||||||
|
modify: func(c *Config) { c.Database.Driver = "postgres" },
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "配置验证失败",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL配置有效",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Database.Driver = "mysql"
|
||||||
|
c.Database.Host = "localhost"
|
||||||
|
c.Database.Port = 3306
|
||||||
|
c.Database.User = "root"
|
||||||
|
c.Database.DBName = "nex"
|
||||||
|
c.Database.Path = ""
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL模式host为空无效",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Database.Driver = "mysql"
|
||||||
|
c.Database.Host = ""
|
||||||
|
c.Database.User = "root"
|
||||||
|
c.Database.DBName = "nex"
|
||||||
|
c.Database.Path = ""
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "配置验证失败",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL模式user为空无效",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Database.Driver = "mysql"
|
||||||
|
c.Database.Host = "localhost"
|
||||||
|
c.Database.User = ""
|
||||||
|
c.Database.DBName = "nex"
|
||||||
|
c.Database.Path = ""
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "配置验证失败",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL模式dbname为空无效",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Database.Driver = "mysql"
|
||||||
|
c.Database.Host = "localhost"
|
||||||
|
c.Database.User = "root"
|
||||||
|
c.Database.DBName = ""
|
||||||
|
c.Database.Path = ""
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "配置验证失败",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL模式忽略path字段",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Database.Driver = "mysql"
|
||||||
|
c.Database.Host = "localhost"
|
||||||
|
c.Database.User = "root"
|
||||||
|
c.Database.DBName = "nex"
|
||||||
|
c.Database.Path = ""
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -100,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
|
|||||||
err := cfg.Validate()
|
err := cfg.Validate()
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), tt.errMsg)
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
@@ -140,7 +214,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
|||||||
WriteTimeout: 20 * time.Second,
|
WriteTimeout: 20 * time.Second,
|
||||||
},
|
},
|
||||||
Database: DatabaseConfig{
|
Database: DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
Path: filepath.Join(dir, "test.db"),
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
Port: 3306,
|
||||||
|
DBName: "nex",
|
||||||
MaxIdleConns: 5,
|
MaxIdleConns: 5,
|
||||||
MaxOpenConns: 50,
|
MaxOpenConns: 50,
|
||||||
ConnMaxLifetime: 30 * time.Minute,
|
ConnMaxLifetime: 30 * time.Minute,
|
||||||
@@ -159,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
|||||||
configPath := filepath.Join(dir, "config.yaml")
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
data, err := yaml.Marshal(cfg)
|
data, err := yaml.Marshal(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = os.WriteFile(configPath, data, 0644)
|
err = os.WriteFile(configPath, data, 0o600)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 加载配置
|
// 加载配置
|
||||||
@@ -210,6 +287,9 @@ func TestConfigPriority(t *testing.T) {
|
|||||||
assert.Equal(t, 9826, cfg.Server.Port)
|
assert.Equal(t, 9826, cfg.Server.Port)
|
||||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||||
|
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||||
|
assert.Equal(t, 3306, cfg.Database.Port)
|
||||||
|
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||||
@@ -222,13 +302,21 @@ func TestConfigPriority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPrintSummary(t *testing.T) {
|
func TestPrintSummary(t *testing.T) {
|
||||||
// 测试配置摘要输出
|
t.Run("SQLite模式摘要", func(t *testing.T) {
|
||||||
t.Run("打印配置摘要", func(t *testing.T) {
|
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
// PrintSummary 只是打印,不会返回错误
|
|
||||||
// 这里主要验证不会 panic
|
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
cfg.PrintSummary()
|
cfg.PrintSummary(zap.NewNop())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("MySQL模式摘要", func(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.Database.Driver = "mysql"
|
||||||
|
cfg.Database.Host = "db.example.com"
|
||||||
|
cfg.Database.Port = 3306
|
||||||
|
cfg.Database.User = "nex"
|
||||||
|
cfg.Database.DBName = "nex"
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
cfg.PrintSummary(zap.NewNop())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ type Model struct {
|
|||||||
// UsageStats 用量统计
|
// UsageStats 用量统计
|
||||||
type UsageStats struct {
|
type UsageStats struct {
|
||||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
|
||||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
|
||||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||||
}
|
}
|
||||||
@@ -47,4 +47,3 @@ func (Model) TableName() string {
|
|||||||
func (UsageStats) TableName() string {
|
func (UsageStats) TableName() string {
|
||||||
return "usage_stats"
|
return "usage_stats"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
|||||||
Message: err.Message,
|
Message: err.Message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
body, _ := json.Marshal(errMsg)
|
body, marshalErr := json.Marshal(errMsg)
|
||||||
|
if marshalErr != nil {
|
||||||
|
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||||
|
}
|
||||||
return body, statusCode
|
return body, statusCode
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
|||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||||
m["model"], _ = json.Marshal(newModel)
|
encodedModel, err := json.Marshal(newModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"] = encodedModel
|
||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
}
|
}
|
||||||
return current, rewriteFunc, nil
|
return current, rewriteFunc, nil
|
||||||
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
|||||||
switch ifaceType {
|
switch ifaceType {
|
||||||
case conversion.InterfaceTypeChat:
|
case conversion.InterfaceTypeChat:
|
||||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||||
m["model"], _ = json.Marshal(newModel)
|
encodedModel, err := json.Marshal(newModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"] = encodedModel
|
||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package anthropic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
@@ -48,6 +49,28 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
// docs/api_reference/anthropic defines messages and models under /v1.
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
expected conversion.InterfaceType
|
||||||
|
}{
|
||||||
|
{"/v1/messages", conversion.InterfaceTypeChat},
|
||||||
|
{"/v1/models", conversion.InterfaceTypeModels},
|
||||||
|
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
|
||||||
|
{"/messages", conversion.InterfaceTypePassthrough},
|
||||||
|
{"/models", conversion.InterfaceTypePassthrough},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdapter_BuildUrl(t *testing.T) {
|
func TestAdapter_BuildUrl(t *testing.T) {
|
||||||
a := NewAdapter()
|
a := NewAdapter()
|
||||||
|
|
||||||
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
|||||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
|||||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.True(t, errors.As(err, &convErr))
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
|||||||
t.Run("解码重排序请求", func(t *testing.T) {
|
t.Run("解码重排序请求", func(t *testing.T) {
|
||||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
|||||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("解码重排序响应", func(t *testing.T) {
|
t.Run("解码重排序响应", func(t *testing.T) {
|
||||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("编码重排序响应", func(t *testing.T) {
|
t.Run("编码重排序响应", func(t *testing.T) {
|
||||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
convErr, ok := err.(*conversion.ConversionError)
|
var convErr *conversion.ConversionError
|
||||||
require.True(t, ok)
|
require.ErrorAs(t, err, &convErr)
|
||||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
|||||||
|
|
||||||
var canonicalMsgs []canonical.CanonicalMessage
|
var canonicalMsgs []canonical.CanonicalMessage
|
||||||
for _, msg := range req.Messages {
|
for _, msg := range req.Messages {
|
||||||
decoded := decodeMessage(msg)
|
decoded, err := decodeMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
|
||||||
|
}
|
||||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// decodeMessage 解码 Anthropic 消息
|
// decodeMessage 解码 Anthropic 消息
|
||||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case "user":
|
case "user":
|
||||||
blocks := decodeContentBlocks(msg.Content)
|
blocks, err := decodeContentBlocks(msg.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
var toolResults []canonical.ContentBlock
|
var toolResults []canonical.ContentBlock
|
||||||
var others []canonical.ContentBlock
|
var others []canonical.ContentBlock
|
||||||
for _, b := range blocks {
|
for _, b := range blocks {
|
||||||
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
|||||||
if len(result) == 0 {
|
if len(result) == 0 {
|
||||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||||
}
|
}
|
||||||
return result
|
return result, nil
|
||||||
|
|
||||||
case "assistant":
|
case "assistant":
|
||||||
blocks := decodeContentBlocks(msg.Content)
|
blocks, err := decodeContentBlocks(msg.Content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if len(blocks) == 0 {
|
if len(blocks) == 0 {
|
||||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||||
}
|
}
|
||||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeContentBlocks 解码内容块列表
|
// decodeContentBlocks 解码内容块列表
|
||||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||||
switch v := content.(type) {
|
switch v := content.(type) {
|
||||||
case string:
|
case string:
|
||||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||||
case []any:
|
case []any:
|
||||||
var blocks []canonical.ContentBlock
|
var blocks []canonical.ContentBlock
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
block := decodeSingleContentBlock(m)
|
block, err := decodeSingleContentBlock(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if block != nil {
|
if block != nil {
|
||||||
blocks = append(blocks, *block)
|
blocks = append(blocks, *block)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(blocks) > 0 {
|
if len(blocks) > 0 {
|
||||||
return blocks
|
return blocks, nil
|
||||||
}
|
}
|
||||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||||
case nil:
|
case nil:
|
||||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||||
default:
|
default:
|
||||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeSingleContentBlock 解码单个内容块
|
// decodeSingleContentBlock 解码单个内容块
|
||||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||||
t, _ := m["type"].(string)
|
t, ok := m["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
switch t {
|
switch t {
|
||||||
case "text":
|
case "text":
|
||||||
text, _ := m["text"].(string)
|
text, ok := m["text"].(string)
|
||||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
if !ok {
|
||||||
|
text = ""
|
||||||
|
}
|
||||||
|
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
id, _ := m["id"].(string)
|
id, ok := m["id"].(string)
|
||||||
name, _ := m["name"].(string)
|
if !ok {
|
||||||
input, _ := json.Marshal(m["input"])
|
id = ""
|
||||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
}
|
||||||
|
name, ok := m["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
name = ""
|
||||||
|
}
|
||||||
|
input, err := json.Marshal(m["input"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
toolUseID, _ := m["tool_use_id"].(string)
|
toolUseID, ok := m["tool_use_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
toolUseID = ""
|
||||||
|
}
|
||||||
isErr := false
|
isErr := false
|
||||||
if ie, ok := m["is_error"].(bool); ok {
|
if ie, ok := m["is_error"].(bool); ok {
|
||||||
isErr = ie
|
isErr = ie
|
||||||
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
|||||||
case string:
|
case string:
|
||||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||||
default:
|
default:
|
||||||
content, _ = json.Marshal(cv)
|
encoded, err := json.Marshal(cv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
content = encoded
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
content = json.RawMessage(`""`)
|
content = json.RawMessage(`""`)
|
||||||
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
|||||||
ToolUseID: toolUseID,
|
ToolUseID: toolUseID,
|
||||||
Content: content,
|
Content: content,
|
||||||
IsError: &isErr,
|
IsError: &isErr,
|
||||||
}
|
}, nil
|
||||||
case "thinking":
|
case "thinking":
|
||||||
thinking, _ := m["thinking"].(string)
|
thinking, ok := m["thinking"].(string)
|
||||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
if !ok {
|
||||||
|
thinking = ""
|
||||||
|
}
|
||||||
|
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||||
case "redacted_thinking":
|
case "redacted_thinking":
|
||||||
// 丢弃
|
// 丢弃
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeTools 解码工具定义
|
// decodeTools 解码工具定义
|
||||||
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
|||||||
return canonical.NewToolChoiceAny()
|
return canonical.NewToolChoiceAny()
|
||||||
}
|
}
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
t, _ := v["type"].(string)
|
t, ok := v["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
switch t {
|
switch t {
|
||||||
case "auto":
|
case "auto":
|
||||||
return canonical.NewToolChoiceAuto()
|
return canonical.NewToolChoiceAuto()
|
||||||
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
|||||||
case "any":
|
case "any":
|
||||||
return canonical.NewToolChoiceAny()
|
return canonical.NewToolChoiceAny()
|
||||||
case "tool":
|
case "tool":
|
||||||
name, _ := v["name"].(string)
|
name, ok := v["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
name = ""
|
||||||
|
}
|
||||||
return canonical.NewToolChoiceNamed(name)
|
return canonical.NewToolChoiceNamed(name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
|
|||||||
assert.Equal(t, true, result["stream"])
|
assert.Equal(t, true, result["stream"])
|
||||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||||
|
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, msgs, 1)
|
assert.Len(t, msgs, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
// tool 消息应被合并到相邻 user 消息
|
// tool 消息应被合并到相邻 user 消息
|
||||||
foundToolResult := false
|
foundToolResult := false
|
||||||
for _, m := range msgs {
|
for _, m := range msgs {
|
||||||
msgMap := m.(map[string]any)
|
msgMap, ok := m.(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
if msgMap["role"] == "user" {
|
if msgMap["role"] == "user" {
|
||||||
content, ok := msgMap["content"].([]any)
|
content, ok := msgMap["content"].([]any)
|
||||||
if ok {
|
if ok {
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
block := c.(map[string]any)
|
block, ok := c.(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
if block["type"] == "tool_result" {
|
if block["type"] == "tool_result" {
|
||||||
foundToolResult = true
|
foundToolResult = true
|
||||||
}
|
}
|
||||||
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
firstMsg := msgs[0].(map[string]any)
|
require.True(t, ok)
|
||||||
|
firstMsg, ok := msgs[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "user", firstMsg["role"])
|
assert.Equal(t, "user", firstMsg["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
|||||||
assert.Equal(t, "assistant", result["role"])
|
assert.Equal(t, "assistant", result["role"])
|
||||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||||
|
|
||||||
content := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, content, 1)
|
assert.Len(t, content, 1)
|
||||||
block := content[0].(map[string]any)
|
block, ok := content[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "text", block["type"])
|
assert.Equal(t, "text", block["type"])
|
||||||
assert.Equal(t, "你好", block["text"])
|
assert.Equal(t, "你好", block["text"])
|
||||||
}
|
}
|
||||||
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
data := result["data"].([]any)
|
data, ok := result["data"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, data, 1)
|
assert.Len(t, data, 1)
|
||||||
|
|
||||||
model := data[0].(map[string]any)
|
model, ok := data[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "claude-3-opus", model["id"])
|
assert.Equal(t, "claude-3-opus", model["id"])
|
||||||
// created 应为 RFC3339 格式
|
// created 应为 RFC3339 格式
|
||||||
createdAt, ok := model["created_at"].(string)
|
createdAt, ok := model["created_at"].(string)
|
||||||
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, msgs, 1)
|
assert.Len(t, msgs, 1)
|
||||||
userMsg := msgs[0].(map[string]any)
|
userMsg, ok := msgs[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "user", userMsg["role"])
|
assert.Equal(t, "user", userMsg["role"])
|
||||||
content := userMsg["content"].([]any)
|
content, ok := userMsg["content"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, content, 2)
|
assert.Len(t, content, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
usage := result["usage"].(map[string]any)
|
usage, ok := result["usage"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
_, hasReasoning := usage["reasoning_tokens"]
|
_, hasReasoning := usage["reasoning_tokens"]
|
||||||
assert.False(t, hasReasoning)
|
assert.False(t, hasReasoning)
|
||||||
}
|
}
|
||||||
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
content := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, content, 1)
|
assert.Len(t, content, 1)
|
||||||
block := content[0].(map[string]any)
|
block, ok := content[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "tool_use", block["type"])
|
assert.Equal(t, "tool_use", block["type"])
|
||||||
assert.Equal(t, "tool_1", block["id"])
|
assert.Equal(t, "tool_1", block["id"])
|
||||||
assert.Equal(t, "search", block["name"])
|
assert.Equal(t, "search", block["name"])
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
|
|||||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||||
data := rawChunk
|
data := rawChunk
|
||||||
if len(d.utf8Remainder) > 0 {
|
if len(d.utf8Remainder) > 0 {
|
||||||
data = append(d.utf8Remainder, rawChunk...)
|
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
|
||||||
d.utf8Remainder = nil
|
d.utf8Remainder = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
|||||||
|
|
||||||
for _, line := range strings.Split(text, "\n") {
|
for _, line := range strings.Split(text, "\n") {
|
||||||
line = strings.TrimRight(line, "\r")
|
line = strings.TrimRight(line, "\r")
|
||||||
if strings.HasPrefix(line, "event: ") {
|
switch {
|
||||||
|
case strings.HasPrefix(line, "event: "):
|
||||||
eventType = strings.TrimPrefix(line, "event: ")
|
eventType = strings.TrimPrefix(line, "event: ")
|
||||||
} else if strings.HasPrefix(line, "data: ") {
|
case strings.HasPrefix(line, "data: "):
|
||||||
eventData = strings.TrimPrefix(line, "data: ")
|
eventData = strings.TrimPrefix(line, "data: ")
|
||||||
if eventType != "" && eventData != "" {
|
if eventType != "" && eventData != "" {
|
||||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||||
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
|||||||
}
|
}
|
||||||
eventType = ""
|
eventType = ""
|
||||||
eventData = ""
|
eventData = ""
|
||||||
} else if line == "" {
|
case line == "":
|
||||||
// SSE 事件分隔符
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,15 +51,23 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
|
|||||||
if event.Message != nil {
|
if event.Message != nil {
|
||||||
msg := map[string]any{
|
msg := map[string]any{
|
||||||
"id": event.Message.ID,
|
"id": event.Message.ID,
|
||||||
"model": event.Message.Model,
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
"content": []any{},
|
||||||
|
"model": event.Message.Model,
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
}
|
}
|
||||||
if event.Message.Usage != nil {
|
if event.Message.Usage != nil {
|
||||||
usage := map[string]any{
|
msg["usage"] = map[string]any{
|
||||||
"input_tokens": event.Message.Usage.InputTokens,
|
"input_tokens": event.Message.Usage.InputTokens,
|
||||||
"output_tokens": event.Message.Usage.OutputTokens,
|
"output_tokens": event.Message.Usage.OutputTokens,
|
||||||
}
|
}
|
||||||
msg["usage"] = usage
|
} else {
|
||||||
|
msg["usage"] = map[string]any{
|
||||||
|
"input_tokens": 0,
|
||||||
|
"output_tokens": 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
payload["message"] = msg
|
payload["message"] = msg
|
||||||
}
|
}
|
||||||
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
|||||||
payload["usage"] = map[string]any{
|
payload["usage"] = map[string]any{
|
||||||
"output_tokens": event.Usage.OutputTokens,
|
"output_tokens": event.Usage.OutputTokens,
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
payload["usage"] = map[string]any{
|
||||||
|
"output_tokens": 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return e.marshalEvent("message_delta", payload)
|
return e.marshalEvent("message_delta", payload)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
|||||||
s := string(chunks[0])
|
s := string(chunks[0])
|
||||||
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
||||||
assert.Contains(t, s, "data: ")
|
assert.Contains(t, s, "data: ")
|
||||||
assert.Contains(t, s, "msg_1")
|
|
||||||
assert.Contains(t, s, "claude-3")
|
var payload map[string]any
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for _, l := range lines {
|
||||||
|
if strings.HasPrefix(l, "data: ") {
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := payload["message"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "msg_1", msg["id"])
|
||||||
|
assert.Equal(t, "message", msg["type"])
|
||||||
|
assert.Equal(t, "assistant", msg["role"])
|
||||||
|
assert.Equal(t, []any{}, msg["content"])
|
||||||
|
assert.Equal(t, "claude-3", msg["model"])
|
||||||
|
assert.Nil(t, msg["stop_reason"])
|
||||||
|
assert.Nil(t, msg["stop_sequence"])
|
||||||
|
|
||||||
|
usage, okU := msg["usage"].(map[string]any)
|
||||||
|
require.True(t, okU)
|
||||||
|
assert.Equal(t, float64(0), usage["input_tokens"])
|
||||||
|
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
|
||||||
|
e := NewStreamEncoder()
|
||||||
|
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
|
||||||
|
|
||||||
|
chunks := e.EncodeEvent(event)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
|
||||||
|
s := string(chunks[0])
|
||||||
|
var payload map[string]any
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for _, l := range lines {
|
||||||
|
if strings.HasPrefix(l, "data: ") {
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := payload["message"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
usage, okU := msg["usage"].(map[string]any)
|
||||||
|
require.True(t, okU)
|
||||||
|
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||||
|
assert.Equal(t, float64(50), usage["output_tokens"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
||||||
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cb := payload["content_block"].(map[string]any)
|
cb, ok := payload["content_block"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "text", cb["type"])
|
assert.Equal(t, "text", cb["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cb := payload["content_block"].(map[string]any)
|
cb, ok := payload["content_block"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "tool_use", cb["type"])
|
assert.Equal(t, "tool_use", cb["type"])
|
||||||
assert.Equal(t, "toolu_1", cb["id"])
|
assert.Equal(t, "toolu_1", cb["id"])
|
||||||
assert.Equal(t, "search", cb["name"])
|
assert.Equal(t, "search", cb["name"])
|
||||||
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cb := payload["content_block"].(map[string]any)
|
cb, ok := payload["content_block"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "thinking", cb["type"])
|
assert.Equal(t, "thinking", cb["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delta := payload["delta"].(map[string]any)
|
delta, okd := payload["delta"].(map[string]any)
|
||||||
|
require.True(t, okd)
|
||||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||||
|
|
||||||
|
usage, oku := payload["usage"].(map[string]any)
|
||||||
|
require.True(t, oku, "message_delta SHALL always include usage")
|
||||||
|
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||||
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
u := payload["usage"].(map[string]any)
|
u, oku := payload["usage"].(map[string]any)
|
||||||
|
require.True(t, oku)
|
||||||
assert.Equal(t, float64(88), u["output_tokens"])
|
assert.Equal(t, float64(88), u["output_tokens"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||||
blocks := decodeContentBlocks(nil)
|
blocks, err := decodeContentBlocks(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.Len(t, blocks, 1)
|
assert.Len(t, blocks, 1)
|
||||||
assert.Equal(t, "", blocks[0].Text)
|
assert.Equal(t, "", blocks[0].Text)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||||
blocks := decodeContentBlocks("hello")
|
blocks, err := decodeContentBlocks("hello")
|
||||||
|
require.NoError(t, err)
|
||||||
assert.Len(t, blocks, 1)
|
assert.Len(t, blocks, 1)
|
||||||
assert.Equal(t, "hello", blocks[0].Text)
|
assert.Equal(t, "hello", blocks[0].Text)
|
||||||
}
|
}
|
||||||
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := encodeToolChoice(tt.choice)
|
result := encodeToolChoice(tt.choice)
|
||||||
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
|
r, ok := result.(map[string]any)
|
||||||
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, tt.want["type"], r["type"])
|
||||||
|
assert.Equal(t, tt.want["name"], r["name"])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
tools := result["tools"].([]any)
|
tools, okt := result["tools"].([]any)
|
||||||
|
require.True(t, okt)
|
||||||
assert.Len(t, tools, 1)
|
assert.Len(t, tools, 1)
|
||||||
tool := tools[0].(map[string]any)
|
tool, okt2 := tools[0].(map[string]any)
|
||||||
|
require.True(t, okt2)
|
||||||
assert.Equal(t, "search", tool["name"])
|
assert.Equal(t, "search", tool["name"])
|
||||||
assert.Equal(t, "Search things", tool["description"])
|
assert.Equal(t, "Search things", tool["description"])
|
||||||
tc := result["tool_choice"].(map[string]any)
|
tc, oktc := result["tool_choice"].(map[string]any)
|
||||||
|
require.True(t, oktc)
|
||||||
assert.Equal(t, "auto", tc["type"])
|
assert.Equal(t, "auto", tc["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
usage := result["usage"].(map[string]any)
|
usage, oku := result["usage"].(map[string]any)
|
||||||
|
require.True(t, oku)
|
||||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||||
|
|||||||
@@ -3,10 +3,14 @@ package conversion
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"nex/backend/internal/conversion/canonical"
|
||||||
|
pkglogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPRequestSpec HTTP 请求规格
|
// HTTPRequestSpec HTTP 请求规格
|
||||||
@@ -33,13 +37,10 @@ type ConversionEngine struct {
|
|||||||
|
|
||||||
// NewConversionEngine 创建转换引擎
|
// NewConversionEngine 创建转换引擎
|
||||||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||||||
if logger == nil {
|
|
||||||
logger = zap.L()
|
|
||||||
}
|
|
||||||
return &ConversionEngine{
|
return &ConversionEngine{
|
||||||
registry: registry,
|
registry: registry,
|
||||||
middlewareChain: NewMiddlewareChain(),
|
middlewareChain: NewMiddlewareChain(),
|
||||||
logger: logger,
|
logger: pkglogger.WithModule(logger, "conversion.engine"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
|
|||||||
|
|
||||||
// ConvertHttpRequest 转换 HTTP 请求
|
// ConvertHttpRequest 转换 HTTP 请求
|
||||||
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
||||||
nativePath := spec.URL
|
nativePath, rawQuery := splitRequestPath(spec.URL)
|
||||||
|
|
||||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||||
@@ -90,15 +91,18 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
|||||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||||
zap.String("error", err.Error()),
|
zap.Error(err),
|
||||||
zap.String("interface", string(interfaceType)))
|
zap.String("interface", string(interfaceType)))
|
||||||
rewrittenBody = spec.Body
|
rewrittenBody = spec.Body
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||||
|
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||||
|
|
||||||
return &HTTPRequestSpec{
|
return &HTTPRequestSpec{
|
||||||
URL: provider.BaseURL + nativePath,
|
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||||
Method: spec.Method,
|
Method: spec.Method,
|
||||||
Headers: providerAdapter.BuildHeaders(provider),
|
Headers: providerAdapter.BuildHeaders(provider),
|
||||||
Body: rewrittenBody,
|
Body: rewrittenBody,
|
||||||
@@ -115,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
|||||||
}
|
}
|
||||||
|
|
||||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||||
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||||
|
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &HTTPRequestSpec{
|
return &HTTPRequestSpec{
|
||||||
URL: provider.BaseURL + providerUrl,
|
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||||
Method: spec.Method,
|
Method: spec.Method,
|
||||||
Headers: providerHeaders,
|
Headers: providerHeaders,
|
||||||
Body: providerBody,
|
Body: providerBody,
|
||||||
@@ -135,25 +140,22 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
|||||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||||
if modelOverride != "" && len(spec.Body) > 0 {
|
if modelOverride != "" && len(spec.Body) > 0 {
|
||||||
adapter, err := e.registry.Get(clientProtocol)
|
adapter, getErr := e.registry.Get(clientProtocol)
|
||||||
if err != nil {
|
if getErr == nil {
|
||||||
return &spec, nil
|
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||||
}
|
if rewriteErr != nil {
|
||||||
|
|
||||||
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||||
zap.String("error", err.Error()),
|
zap.Error(rewriteErr),
|
||||||
zap.String("interface", string(interfaceType)))
|
zap.String("interface", string(interfaceType)))
|
||||||
return &spec, nil
|
} else {
|
||||||
}
|
|
||||||
|
|
||||||
return &HTTPResponseSpec{
|
return &HTTPResponseSpec{
|
||||||
StatusCode: spec.StatusCode,
|
StatusCode: spec.StatusCode,
|
||||||
Headers: spec.Headers,
|
Headers: spec.Headers,
|
||||||
Body: rewrittenBody,
|
Body: rewrittenBody,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return &spec, nil
|
return &spec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,12 +185,11 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
|||||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||||
if modelOverride != "" {
|
if modelOverride != "" {
|
||||||
adapter, err := e.registry.Get(clientProtocol)
|
adapter, getErr := e.registry.Get(clientProtocol)
|
||||||
if err != nil {
|
if getErr == nil {
|
||||||
return NewPassthroughStreamConverter(), nil
|
|
||||||
}
|
|
||||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return NewPassthroughStreamConverter(), nil
|
return NewPassthroughStreamConverter(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +204,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
|||||||
|
|
||||||
ctx := ConversionContext{
|
ctx := ConversionContext{
|
||||||
ConversionID: uuid.New().String(),
|
ConversionID: uuid.New().String(),
|
||||||
InterfaceType: InterfaceTypeChat,
|
InterfaceType: interfaceType,
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
|||||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
return nil, NewRequestJSONParseError("解码请求失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := NewConversionContext(InterfaceTypeChat)
|
ctx := NewConversionContext(InterfaceTypeChat)
|
||||||
@@ -281,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if containsUnsupportedMultimodal(canonicalReq) {
|
||||||
|
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
|
||||||
|
}
|
||||||
|
|
||||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -292,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
|||||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
return nil, NewResponseJSONParseError("解码响应失败", err)
|
||||||
}
|
}
|
||||||
if modelOverride != "" {
|
if modelOverride != "" {
|
||||||
canonicalResp.Model = modelOverride
|
canonicalResp.Model = modelOverride
|
||||||
@@ -307,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
|
|||||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return encoded, nil
|
return encoded, nil
|
||||||
@@ -321,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
|
|||||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return encoded, nil
|
return encoded, nil
|
||||||
@@ -335,7 +339,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
|
|||||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||||
@@ -344,7 +348,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
|
|||||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
if modelOverride != "" {
|
if modelOverride != "" {
|
||||||
@@ -356,29 +360,31 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
|
|||||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||||
if err != nil {
|
if decodeErr == nil {
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
if modelOverride != "" {
|
if modelOverride != "" {
|
||||||
resp.Model = modelOverride
|
resp.Model = modelOverride
|
||||||
}
|
}
|
||||||
return clientAdapter.EncodeRerankResponse(resp)
|
return clientAdapter.EncodeRerankResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DetectInterfaceType 检测接口类型
|
// DetectInterfaceType 检测接口类型
|
||||||
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
|
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
|
||||||
adapter, err := e.registry.Get(clientProtocol)
|
adapter, err := e.registry.Get(clientProtocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return InterfaceTypePassthrough, err
|
return InterfaceTypePassthrough, err
|
||||||
}
|
}
|
||||||
|
nativePath, _ = splitRequestPath(nativePath)
|
||||||
return adapter.DetectInterfaceType(nativePath), nil
|
return adapter.DetectInterfaceType(nativePath), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
|||||||
"type": "internal_error",
|
"type": "internal_error",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
body, _ := json.Marshal(fallback)
|
body, marshalErr := json.Marshal(fallback)
|
||||||
|
if marshalErr == nil {
|
||||||
return body, 500, nil
|
return body, 500, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||||
|
}
|
||||||
body, statusCode := adapter.EncodeError(err)
|
body, statusCode := adapter.EncodeError(err)
|
||||||
return body, statusCode, nil
|
return body, statusCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitRequestPath(rawPath string) (string, string) {
|
||||||
|
path, query, found := strings.Cut(rawPath, "?")
|
||||||
|
if !found {
|
||||||
|
return rawPath, ""
|
||||||
|
}
|
||||||
|
return path, query
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendRawQuery(path, rawQuery string) string {
|
||||||
|
if rawQuery == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if strings.Contains(path, "?") {
|
||||||
|
return path + "&" + rawQuery
|
||||||
|
}
|
||||||
|
return path + "?" + rawQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinBaseURL(baseURL, path string) string {
|
||||||
|
if baseURL == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if path == "" {
|
||||||
|
return baseURL
|
||||||
|
}
|
||||||
|
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
|
||||||
|
if req == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
for _, block := range msg.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "image", "audio", "video", "file":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
63
backend/internal/conversion/engine_adapter_test.go
Normal file
63
backend/internal/conversion/engine_adapter_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package conversion_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/conversion"
|
||||||
|
"nex/backend/internal/conversion/anthropic"
|
||||||
|
"nex/backend/internal/conversion/openai"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
adapter conversion.ProtocolAdapter
|
||||||
|
clientProtocol string
|
||||||
|
providerProtocol string
|
||||||
|
baseURL string
|
||||||
|
nativePath string
|
||||||
|
expectedURL string
|
||||||
|
body []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai base url includes version path",
|
||||||
|
adapter: openai.NewAdapter(),
|
||||||
|
clientProtocol: "openai",
|
||||||
|
providerProtocol: "openai",
|
||||||
|
baseURL: "http://example.com/v1",
|
||||||
|
nativePath: "/chat/completions",
|
||||||
|
expectedURL: "http://example.com/v1/chat/completions",
|
||||||
|
body: []byte(`{"model":"gpt-4","messages":[]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anthropic native path keeps v1",
|
||||||
|
adapter: anthropic.NewAdapter(),
|
||||||
|
clientProtocol: "anthropic",
|
||||||
|
providerProtocol: "anthropic",
|
||||||
|
baseURL: "http://example.com",
|
||||||
|
nativePath: "/v1/messages",
|
||||||
|
expectedURL: "http://example.com/v1/messages",
|
||||||
|
body: []byte(`{"model":"claude","messages":[]}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
registry := conversion.NewMemoryRegistry()
|
||||||
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
|
require.NoError(t, registry.Register(tt.adapter))
|
||||||
|
|
||||||
|
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
|
||||||
|
URL: tt.nativePath,
|
||||||
|
Method: "POST",
|
||||||
|
Body: tt.body,
|
||||||
|
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedURL, out.URL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConversionError_WithProviderProtocol(t *testing.T) {
|
func TestConversionError_WithProviderProtocol(t *testing.T) {
|
||||||
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
|
|||||||
|
|
||||||
func TestEngine_Use(t *testing.T) {
|
func TestEngine_Use(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
called := false
|
called := false
|
||||||
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||||
called = true
|
called = true
|
||||||
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||||
return nil, errors.New("decode failed")
|
return nil, errors.New("decode failed")
|
||||||
@@ -82,7 +83,7 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
|
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
|
||||||
@@ -98,7 +99,7 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||||
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||||
return nil, errors.New("decode error")
|
return nil, errors.New("decode error")
|
||||||
@@ -135,7 +136,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.ifaceType = InterfaceTypeEmbeddings
|
clientAdapter.ifaceType = InterfaceTypeEmbeddings
|
||||||
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.ifaceType = InterfaceTypeRerank
|
clientAdapter.ifaceType = InterfaceTypeRerank
|
||||||
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||||
@@ -196,7 +197,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||||
@@ -214,7 +215,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.ifaceType = InterfaceTypeModels
|
clientAdapter.ifaceType = InterfaceTypeModels
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
@@ -249,7 +250,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
@@ -324,7 +325,7 @@ var _ = json.Marshal
|
|||||||
|
|
||||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||||
@@ -344,7 +345,7 @@ func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||||
@@ -364,7 +365,7 @@ func TestConvertRerankBody_DecodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package conversion
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"nex/backend/internal/conversion/canonical"
|
"nex/backend/internal/conversion/canonical"
|
||||||
@@ -190,7 +191,9 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
|
|||||||
// noopStreamDecoder 空流式解码器
|
// noopStreamDecoder 空流式解码器
|
||||||
type noopStreamDecoder struct{}
|
type noopStreamDecoder struct{}
|
||||||
|
|
||||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
|
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||||
|
|
||||||
// noopStreamEncoder 空流式编码器
|
// noopStreamEncoder 空流式编码器
|
||||||
@@ -203,7 +206,7 @@ func (e *noopStreamEncoder) Flush() [][]byte
|
|||||||
|
|
||||||
func TestNewConversionEngine(t *testing.T) {
|
func TestNewConversionEngine(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
assert.NotNil(t, engine)
|
assert.NotNil(t, engine)
|
||||||
assert.Equal(t, registry, engine.GetRegistry())
|
assert.Equal(t, registry, engine.GetRegistry())
|
||||||
}
|
}
|
||||||
@@ -211,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
|
|||||||
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||||
t.Run("nil_logger_uses_global", func(t *testing.T) {
|
t.Run("nil_logger_uses_global", func(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
assert.NotNil(t, engine.logger)
|
assert.NotNil(t, engine.logger)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -219,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
|||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
customLogger := zap.NewNop()
|
customLogger := zap.NewNop()
|
||||||
engine := NewConversionEngine(registry, customLogger)
|
engine := NewConversionEngine(registry, customLogger)
|
||||||
assert.Equal(t, customLogger, engine.logger)
|
assert.NotNil(t, engine.logger)
|
||||||
|
assert.Contains(t, engine.logger.Name(), "conversion.engine")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterAdapter(t *testing.T) {
|
func TestRegisterAdapter(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
adapter := newMockAdapter("test-proto", true)
|
adapter := newMockAdapter("test-proto", true)
|
||||||
err := engine.RegisterAdapter(adapter)
|
err := engine.RegisterAdapter(adapter)
|
||||||
@@ -237,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
adapter := newMockAdapter("openai", true)
|
adapter := newMockAdapter("openai", true)
|
||||||
_ = engine.RegisterAdapter(adapter)
|
_ = engine.RegisterAdapter(adapter)
|
||||||
|
|
||||||
@@ -246,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
||||||
|
|
||||||
@@ -255,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
||||||
|
|
||||||
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
||||||
@@ -263,7 +267,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
|||||||
|
|
||||||
func TestDetectInterfaceType(t *testing.T) {
|
func TestDetectInterfaceType(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
adapter := newMockAdapter("test", true)
|
adapter := newMockAdapter("test", true)
|
||||||
adapter.ifaceType = InterfaceTypeChat
|
adapter.ifaceType = InterfaceTypeChat
|
||||||
_ = engine.RegisterAdapter(adapter)
|
_ = engine.RegisterAdapter(adapter)
|
||||||
@@ -275,7 +279,7 @@ func TestDetectInterfaceType(t *testing.T) {
|
|||||||
|
|
||||||
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -283,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
openaiAdapter := &buildURLMockAdapter{
|
||||||
|
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||||
|
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||||
|
if interfaceType == InterfaceTypeChat {
|
||||||
|
return "/chat/completions"
|
||||||
|
}
|
||||||
|
return nativePath
|
||||||
|
},
|
||||||
|
}
|
||||||
|
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||||
|
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||||
|
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(openaiAdapter)
|
||||||
|
|
||||||
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
|
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
|
||||||
spec := HTTPRequestSpec{
|
spec := HTTPRequestSpec{
|
||||||
URL: "/chat/completions",
|
URL: "/v1/chat/completions",
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
|
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||||
assert.Equal(t, spec.Body, result.Body)
|
assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client-proto", false)
|
clientAdapter := newMockAdapter("client-proto", false)
|
||||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||||
@@ -331,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
|||||||
assert.NotNil(t, result.Body)
|
assert.NotNil(t, result.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
openaiAdapter := &buildURLMockAdapter{
|
||||||
|
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||||
|
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||||
|
if interfaceType == InterfaceTypeChat {
|
||||||
|
return "/chat/completions"
|
||||||
|
}
|
||||||
|
return nativePath
|
||||||
|
},
|
||||||
|
}
|
||||||
|
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||||
|
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||||
|
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
return []byte(`{"model":"` + newModel + `"}`), nil
|
||||||
|
}
|
||||||
|
require.NoError(t, registry.Register(openaiAdapter))
|
||||||
|
|
||||||
|
anthropicAdapter := &buildURLMockAdapter{
|
||||||
|
mockProtocolAdapter: newMockAdapter("anthropic", false),
|
||||||
|
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||||
|
if interfaceType == InterfaceTypeChat {
|
||||||
|
return "/v1/messages"
|
||||||
|
}
|
||||||
|
return nativePath
|
||||||
|
},
|
||||||
|
}
|
||||||
|
anthropicAdapter.ifaceType = InterfaceTypeChat
|
||||||
|
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
|
||||||
|
require.NoError(t, registry.Register(anthropicAdapter))
|
||||||
|
|
||||||
|
t.Run("OpenAI to Anthropic", func(t *testing.T) {
|
||||||
|
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||||
|
spec := HTTPRequestSpec{
|
||||||
|
URL: "/v1/chat/completions",
|
||||||
|
Method: "POST",
|
||||||
|
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Anthropic to OpenAI", func(t *testing.T) {
|
||||||
|
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
|
||||||
|
spec := HTTPRequestSpec{
|
||||||
|
URL: "/v1/messages",
|
||||||
|
Method: "POST",
|
||||||
|
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type buildURLMockAdapter struct {
|
||||||
|
*mockProtocolAdapter
|
||||||
|
buildURLFn func(string, InterfaceType) string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
|
||||||
|
if m.buildURLFn != nil {
|
||||||
|
return m.buildURLFn(nativePath, interfaceType)
|
||||||
|
}
|
||||||
|
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||||
|
|
||||||
spec := HTTPResponseSpec{
|
spec := HTTPResponseSpec{
|
||||||
@@ -349,7 +438,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||||
|
|
||||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||||
@@ -360,7 +449,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||||
|
|
||||||
@@ -372,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
|||||||
|
|
||||||
func TestEncodeError(t *testing.T) {
|
func TestEncodeError(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||||
|
|
||||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||||
@@ -384,7 +473,7 @@ func TestEncodeError(t *testing.T) {
|
|||||||
|
|
||||||
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||||
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
||||||
@@ -417,7 +506,7 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
clientAdapter := newMockAdapter("client", false)
|
clientAdapter := newMockAdapter("client", false)
|
||||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||||
@@ -446,7 +535,7 @@ func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||||
openaiAdapter := newMockAdapter("openai", true)
|
openaiAdapter := newMockAdapter("openai", true)
|
||||||
@@ -476,7 +565,7 @@ func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
openaiAdapter := newMockAdapter("openai", true)
|
openaiAdapter := newMockAdapter("openai", true)
|
||||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
@@ -495,18 +584,19 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
|||||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// 验证 chunk 改写
|
// 验证 SSE frame 中的 data JSON 被改写
|
||||||
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
|
||||||
require.Len(t, chunks, 1)
|
require.Len(t, chunks, 1)
|
||||||
|
|
||||||
var resp map[string]interface{}
|
var resp map[string]interface{}
|
||||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
|
||||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
// provider adapter 解码出含 model 的流式事件
|
// provider adapter 解码出含 model 的流式事件
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
@@ -560,7 +650,7 @@ func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||||
registry := NewMemoryRegistry()
|
registry := NewMemoryRegistry()
|
||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, zap.NewNop())
|
||||||
|
|
||||||
providerAdapter := newMockAdapter("provider", false)
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||||
@@ -614,6 +704,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||||
if d.flushFn != nil {
|
if d.flushFn != nil {
|
||||||
return d.flushFn()
|
return d.flushFn()
|
||||||
@@ -633,6 +724,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||||
if e.flushFn != nil {
|
if e.flushFn != nil {
|
||||||
return e.flushFn()
|
return e.flushFn()
|
||||||
|
|||||||
@@ -17,6 +17,13 @@ const (
|
|||||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||||
|
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrorDetailPhase = "phase"
|
||||||
|
ErrorPhaseRequest = "request"
|
||||||
|
ErrorPhaseResponse = "response"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConversionError 协议转换错误
|
// ConversionError 协议转换错误
|
||||||
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRequestJSONParseError 创建请求 JSON 解析错误。
|
||||||
|
func NewRequestJSONParseError(message string, cause error) *ConversionError {
|
||||||
|
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||||
|
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
|
||||||
|
WithCause(cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponseJSONParseError 创建响应 JSON 解析错误。
|
||||||
|
func NewResponseJSONParseError(message string, cause error) *ConversionError {
|
||||||
|
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||||
|
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
|
||||||
|
WithCause(cause)
|
||||||
|
}
|
||||||
|
|
||||||
// WithClientProtocol 设置客户端协议
|
// WithClientProtocol 设置客户端协议
|
||||||
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
|
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
|
||||||
e.ClientProtocol = protocol
|
e.ClientProtocol = protocol
|
||||||
|
|||||||
@@ -29,27 +29,27 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
|
|||||||
// DetectInterfaceType 根据路径检测接口类型
|
// DetectInterfaceType 根据路径检测接口类型
|
||||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||||
switch {
|
switch {
|
||||||
case nativePath == "/chat/completions":
|
case nativePath == "/v1/chat/completions":
|
||||||
return conversion.InterfaceTypeChat
|
return conversion.InterfaceTypeChat
|
||||||
case nativePath == "/models":
|
case nativePath == "/v1/models":
|
||||||
return conversion.InterfaceTypeModels
|
return conversion.InterfaceTypeModels
|
||||||
case isModelInfoPath(nativePath):
|
case isModelInfoPath(nativePath):
|
||||||
return conversion.InterfaceTypeModelInfo
|
return conversion.InterfaceTypeModelInfo
|
||||||
case nativePath == "/embeddings":
|
case nativePath == "/v1/embeddings":
|
||||||
return conversion.InterfaceTypeEmbeddings
|
return conversion.InterfaceTypeEmbeddings
|
||||||
case nativePath == "/rerank":
|
case nativePath == "/v1/rerank":
|
||||||
return conversion.InterfaceTypeRerank
|
return conversion.InterfaceTypeRerank
|
||||||
default:
|
default:
|
||||||
return conversion.InterfaceTypePassthrough
|
return conversion.InterfaceTypePassthrough
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 /)
|
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||||
func isModelInfoPath(path string) bool {
|
func isModelInfoPath(path string) bool {
|
||||||
if !strings.HasPrefix(path, "/models/") {
|
if !strings.HasPrefix(path, "/v1/models/") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
suffix := path[len("/models/"):]
|
suffix := path[len("/v1/models/"):]
|
||||||
return suffix != ""
|
return suffix != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,6 +60,11 @@ func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.Interface
|
|||||||
return "/chat/completions"
|
return "/chat/completions"
|
||||||
case conversion.InterfaceTypeModels:
|
case conversion.InterfaceTypeModels:
|
||||||
return "/models"
|
return "/models"
|
||||||
|
case conversion.InterfaceTypeModelInfo:
|
||||||
|
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
|
||||||
|
return "/models/" + modelID
|
||||||
|
}
|
||||||
|
return nativePath
|
||||||
case conversion.InterfaceTypeEmbeddings:
|
case conversion.InterfaceTypeEmbeddings:
|
||||||
return "/embeddings"
|
return "/embeddings"
|
||||||
case conversion.InterfaceTypeRerank:
|
case conversion.InterfaceTypeRerank:
|
||||||
@@ -138,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
|||||||
Code: string(err.Code),
|
Code: string(err.Code),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
body, _ := json.Marshal(errMsg)
|
body, marshalErr := json.Marshal(errMsg)
|
||||||
|
if marshalErr != nil {
|
||||||
|
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
|
||||||
|
}
|
||||||
return body, statusCode
|
return body, statusCode
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,12 +226,12 @@ func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse)
|
|||||||
return encodeRerankResponse(resp)
|
return encodeRerankResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/models/{provider_id}/{model_name})
|
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||||
if !strings.HasPrefix(nativePath, "/models/") {
|
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||||
}
|
}
|
||||||
suffix := nativePath[len("/models/"):]
|
suffix := nativePath[len("/v1/models/"):]
|
||||||
if suffix == "" {
|
if suffix == "" {
|
||||||
return "", fmt.Errorf("路径缺少模型 ID")
|
return "", fmt.Errorf("路径缺少模型 ID")
|
||||||
}
|
}
|
||||||
@@ -248,7 +256,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
|||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||||
m["model"], _ = json.Marshal(newModel)
|
encodedModel, err := json.Marshal(newModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"] = encodedModel
|
||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
}
|
}
|
||||||
return current, rewriteFunc, nil
|
return current, rewriteFunc, nil
|
||||||
@@ -282,12 +294,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
|||||||
switch ifaceType {
|
switch ifaceType {
|
||||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||||
m["model"], _ = json.Marshal(newModel)
|
encodedModel, err := json.Marshal(newModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"] = encodedModel
|
||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
case conversion.InterfaceTypeRerank:
|
case conversion.InterfaceTypeRerank:
|
||||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||||
if _, exists := m["model"]; exists {
|
if _, exists := m["model"]; exists {
|
||||||
m["model"], _ = json.Marshal(newModel)
|
encodedModel, err := json.Marshal(newModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"] = encodedModel
|
||||||
}
|
}
|
||||||
return json.Marshal(m)
|
return json.Marshal(m)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
|||||||
path string
|
path string
|
||||||
expected conversion.InterfaceType
|
expected conversion.InterfaceType
|
||||||
}{
|
}{
|
||||||
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat},
|
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
|
||||||
{"模型列表", "/models", conversion.InterfaceTypeModels},
|
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||||
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo},
|
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||||
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings},
|
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||||
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank},
|
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
|
||||||
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
|
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,6 +44,27 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
expected conversion.InterfaceType
|
||||||
|
}{
|
||||||
|
{"/chat/completions", conversion.InterfaceTypePassthrough},
|
||||||
|
{"/models", conversion.InterfaceTypePassthrough},
|
||||||
|
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
|
||||||
|
{"/embeddings", conversion.InterfaceTypePassthrough},
|
||||||
|
{"/rerank", conversion.InterfaceTypePassthrough},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.path, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdapter_BuildUrl(t *testing.T) {
|
func TestAdapter_BuildUrl(t *testing.T) {
|
||||||
a := NewAdapter()
|
a := NewAdapter()
|
||||||
|
|
||||||
@@ -53,10 +74,12 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
|||||||
interfaceType conversion.InterfaceType
|
interfaceType conversion.InterfaceType
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||||
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
|
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
|
||||||
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
|
||||||
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
|
||||||
|
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
||||||
|
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
||||||
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
|
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,12 +141,12 @@ func TestIsModelInfoPath(t *testing.T) {
|
|||||||
path string
|
path string
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{"model_info", "/models/gpt-4", true},
|
{"model_info", "/v1/models/openai/gpt-4", true},
|
||||||
{"model_info_with_dots", "/models/gpt-4.1-preview", true},
|
{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
|
||||||
{"models_list", "/models", false},
|
{"models_list", "/v1/models", false},
|
||||||
{"nested_path", "/models/gpt-4/versions", true},
|
{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
|
||||||
{"empty_suffix", "/models/", false},
|
{"empty_suffix", "/v1/models/", false},
|
||||||
{"unrelated", "/chat/completions", false},
|
{"unrelated", "/v1/chat/completions", false},
|
||||||
{"partial_prefix", "/model", false},
|
{"partial_prefix", "/model", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("标准路径", func(t *testing.T) {
|
||||||
|
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/gpt-4", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("复杂路径", func(t *testing.T) {
|
||||||
|
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非模型详情路径报错", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
|
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
|
||||||
a := NewAdapter()
|
a := NewAdapter()
|
||||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
||||||
|
|||||||
@@ -18,35 +18,35 @@ func TestExtractUnifiedModelID(t *testing.T) {
|
|||||||
a := NewAdapter()
|
a := NewAdapter()
|
||||||
|
|
||||||
t.Run("standard_path", func(t *testing.T) {
|
t.Run("standard_path", func(t *testing.T) {
|
||||||
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4")
|
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "openai/gpt-4", id)
|
assert.Equal(t, "openai/gpt-4", id)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("multi_segment_path", func(t *testing.T) {
|
t.Run("multi_segment_path", func(t *testing.T) {
|
||||||
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4")
|
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("single_segment", func(t *testing.T) {
|
t.Run("single_segment", func(t *testing.T) {
|
||||||
id, err := a.ExtractUnifiedModelID("/models/gpt-4")
|
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "gpt-4", id)
|
assert.Equal(t, "gpt-4", id)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("non_model_path", func(t *testing.T) {
|
t.Run("non_model_path", func(t *testing.T) {
|
||||||
_, err := a.ExtractUnifiedModelID("/chat/completions")
|
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty_suffix", func(t *testing.T) {
|
t.Run("empty_suffix", func(t *testing.T) {
|
||||||
_, err := a.ExtractUnifiedModelID("/models/")
|
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||||
_, err := a.ExtractUnifiedModelID("/models")
|
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -344,12 +344,12 @@ func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
|||||||
path string
|
path string
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{"simple_model_id", "/models/gpt-4", true},
|
{"simple_model_id", "/v1/models/gpt-4", true},
|
||||||
{"unified_model_id_with_slash", "/models/openai/gpt-4", true},
|
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
|
||||||
{"models_list", "/models", false},
|
{"models_list", "/v1/models", false},
|
||||||
{"models_list_trailing_slash", "/models/", false},
|
{"models_list_trailing_slash", "/v1/models/", false},
|
||||||
{"chat_completions", "/chat/completions", false},
|
{"chat_completions", "/v1/chat/completions", false},
|
||||||
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
|||||||
var blocks []canonical.ContentBlock
|
var blocks []canonical.ContentBlock
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
t, _ := m["type"].(string)
|
t, ok := m["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
switch t {
|
switch t {
|
||||||
case "text":
|
case "text":
|
||||||
text, _ := m["text"].(string)
|
text, ok := m["text"].(string)
|
||||||
|
if !ok {
|
||||||
|
text = ""
|
||||||
|
}
|
||||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||||
case "image_url":
|
case "image_url":
|
||||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||||
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
|
|||||||
var result []contentPart
|
var result []contentPart
|
||||||
for _, item := range parts {
|
for _, item := range parts {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
t, _ := m["type"].(string)
|
t, ok := m["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
switch t {
|
switch t {
|
||||||
case "text":
|
case "text":
|
||||||
text, _ := m["text"].(string)
|
text, ok := m["text"].(string)
|
||||||
|
if !ok {
|
||||||
|
text = ""
|
||||||
|
}
|
||||||
result = append(result, contentPart{Type: "text", Text: text})
|
result = append(result, contentPart{Type: "text", Text: text})
|
||||||
case "refusal":
|
case "refusal":
|
||||||
refusal, _ := m["refusal"].(string)
|
refusal, ok := m["refusal"].(string)
|
||||||
|
if !ok {
|
||||||
|
refusal = ""
|
||||||
|
}
|
||||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
|||||||
return canonical.NewToolChoiceAny()
|
return canonical.NewToolChoiceAny()
|
||||||
}
|
}
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
t, _ := v["type"].(string)
|
t, ok := v["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
switch t {
|
switch t {
|
||||||
case "function":
|
case "function":
|
||||||
if fn, ok := v["function"].(map[string]any); ok {
|
if fn, ok := v["function"].(map[string]any); ok {
|
||||||
name, _ := fn["name"].(string)
|
name, ok := fn["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
name = ""
|
||||||
|
}
|
||||||
return canonical.NewToolChoiceNamed(name)
|
return canonical.NewToolChoiceNamed(name)
|
||||||
}
|
}
|
||||||
case "custom":
|
case "custom":
|
||||||
if custom, ok := v["custom"].(map[string]any); ok {
|
if custom, ok := v["custom"].(map[string]any); ok {
|
||||||
name, _ := custom["name"].(string)
|
name, ok := custom["name"].(string)
|
||||||
|
if !ok {
|
||||||
|
name = ""
|
||||||
|
}
|
||||||
return canonical.NewToolChoiceNamed(name)
|
return canonical.NewToolChoiceNamed(name)
|
||||||
}
|
}
|
||||||
case "allowed_tools":
|
case "allowed_tools":
|
||||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||||
mode, _ := at["mode"].(string)
|
mode, ok := at["mode"].(string)
|
||||||
|
if !ok {
|
||||||
|
mode = ""
|
||||||
|
}
|
||||||
if mode == "required" {
|
if mode == "required" {
|
||||||
return canonical.NewToolChoiceAny()
|
return canonical.NewToolChoiceAny()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Len(t, msgs, 2)
|
assert.Len(t, msgs, 2)
|
||||||
firstMsg := msgs[0].(map[string]any)
|
firstMsg, ok := msgs[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "system", firstMsg["role"])
|
assert.Equal(t, "system", firstMsg["role"])
|
||||||
assert.Equal(t, "你是助手", firstMsg["content"])
|
assert.Equal(t, "你是助手", firstMsg["content"])
|
||||||
}
|
}
|
||||||
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
msgs := result["messages"].([]any)
|
msgs, ok := result["messages"].([]any)
|
||||||
assistantMsg := msgs[0].(map[string]any)
|
require.True(t, ok)
|
||||||
|
assistantMsg, ok := msgs[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, toolCalls, 1)
|
assert.Len(t, toolCalls, 1)
|
||||||
tc := toolCalls[0].(map[string]any)
|
tc, ok := toolCalls[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "call_1", tc["id"])
|
assert.Equal(t, "call_1", tc["id"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
|||||||
assert.Equal(t, "resp-1", result["id"])
|
assert.Equal(t, "resp-1", result["id"])
|
||||||
assert.Equal(t, "chat.completion", result["object"])
|
assert.Equal(t, "chat.completion", result["object"])
|
||||||
|
|
||||||
choices := result["choices"].([]any)
|
choices, ok := result["choices"].([]any)
|
||||||
choice := choices[0].(map[string]any)
|
require.True(t, ok)
|
||||||
msg := choice["message"].(map[string]any)
|
choice, ok := choices[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
msg, ok := choice["message"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
assert.Equal(t, "你好", msg["content"])
|
assert.Equal(t, "你好", msg["content"])
|
||||||
assert.Equal(t, "stop", choice["finish_reason"])
|
assert.Equal(t, "stop", choice["finish_reason"])
|
||||||
}
|
}
|
||||||
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
choices := result["choices"].([]any)
|
choices, okc := result["choices"].([]any)
|
||||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
require.True(t, okc)
|
||||||
|
msgMap, okm := choices[0].(map[string]any)
|
||||||
|
require.True(t, okm)
|
||||||
|
msg, okmsg := msgMap["message"].(map[string]any)
|
||||||
|
require.True(t, okmsg)
|
||||||
tcs, ok := msg["tool_calls"].([]any)
|
tcs, ok := msg["tool_calls"].([]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, tcs, 1)
|
assert.Len(t, tcs, 1)
|
||||||
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
|
|||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
assert.Equal(t, "list", result["object"])
|
assert.Equal(t, "list", result["object"])
|
||||||
data := result["data"].([]any)
|
data, okd := result["data"].([]any)
|
||||||
|
require.True(t, okd)
|
||||||
assert.Len(t, data, 2)
|
assert.Len(t, data, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
choices := result["choices"].([]any)
|
choices, okch := result["choices"].([]any)
|
||||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
require.True(t, okch)
|
||||||
|
msgMap, okmm := choices[0].(map[string]any)
|
||||||
|
require.True(t, okmm)
|
||||||
|
msg, okmsg := msgMap["message"].(map[string]any)
|
||||||
|
require.True(t, okmsg)
|
||||||
assert.Equal(t, "回答", msg["content"])
|
assert.Equal(t, "回答", msg["content"])
|
||||||
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
|||||||
data := strings.TrimPrefix(s, "data: ")
|
data := strings.TrimPrefix(s, "data: ")
|
||||||
data = strings.TrimRight(data, "\n")
|
data = strings.TrimRight(data, "\n")
|
||||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||||
choices := payload["choices"].([]any)
|
choices, okch := payload["choices"].([]any)
|
||||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
require.True(t, okch)
|
||||||
|
msgMap, okmm := choices[0].(map[string]any)
|
||||||
|
require.True(t, okmm)
|
||||||
|
delta, okd := msgMap["delta"].(map[string]any)
|
||||||
|
require.True(t, okd)
|
||||||
assert.Equal(t, "assistant", delta["role"])
|
assert.Equal(t, "assistant", delta["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
|
|||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
assert.Equal(t, "rerank-1", result["model"])
|
assert.Equal(t, "rerank-1", result["model"])
|
||||||
results := result["results"].([]any)
|
results, okr := result["results"].([]any)
|
||||||
|
require.True(t, okr)
|
||||||
assert.Len(t, results, 1)
|
assert.Len(t, results, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
usage := result["usage"].(map[string]any)
|
usage, oku := result["usage"].(map[string]any)
|
||||||
|
require.True(t, oku)
|
||||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
|
|||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
require.NoError(t, json.Unmarshal(body, &result))
|
require.NoError(t, json.Unmarshal(body, &result))
|
||||||
choices := result["choices"].([]any)
|
choices, okch := result["choices"].([]any)
|
||||||
choice := choices[0].(map[string]any)
|
require.True(t, okch)
|
||||||
|
choice, okc := choices[0].(map[string]any)
|
||||||
|
require.True(t, okc)
|
||||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package conversion
|
package conversion
|
||||||
|
|
||||||
import "nex/backend/internal/conversion/canonical"
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"nex/backend/internal/conversion/canonical"
|
||||||
|
)
|
||||||
|
|
||||||
// StreamDecoder 流式解码器接口
|
// StreamDecoder 流式解码器接口
|
||||||
type StreamDecoder interface {
|
type StreamDecoder interface {
|
||||||
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||||
// 逐 chunk 改写 model 字段
|
// 按 SSE frame 改写 data JSON 中的 model 字段
|
||||||
type SmartPassthroughStreamConverter struct {
|
type SmartPassthroughStreamConverter struct {
|
||||||
adapter ProtocolAdapter
|
adapter ProtocolAdapter
|
||||||
modelOverride string
|
modelOverride string
|
||||||
interfaceType InterfaceType
|
interfaceType InterfaceType
|
||||||
|
buffer []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||||
@@ -55,25 +61,46 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessChunk 改写 chunk 中的 model 字段
|
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
||||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||||
if len(rawChunk) == 0 {
|
if len(rawChunk) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
c.buffer = append(c.buffer, rawChunk...)
|
||||||
|
frames, rest := splitSSEFrames(c.buffer)
|
||||||
|
c.buffer = rest
|
||||||
|
|
||||||
|
result := make([][]byte, 0, len(frames))
|
||||||
|
for _, frame := range frames {
|
||||||
|
result = append(result, c.rewriteFrame(frame))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
||||||
|
payload, ok := sseFrameDataPayload(frame)
|
||||||
|
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
||||||
|
return frame
|
||||||
|
}
|
||||||
|
|
||||||
|
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 改写失败,返回原始 chunk
|
return frame
|
||||||
return [][]byte{rawChunk}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return [][]byte{rewrittenChunk}
|
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush 无缓冲数据
|
// Flush 输出未形成完整 frame 的剩余数据
|
||||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||||
|
if len(c.buffer) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
frame := append([]byte(nil), c.buffer...)
|
||||||
|
c.buffer = nil
|
||||||
|
return [][]byte{c.rewriteFrame(frame)}
|
||||||
|
}
|
||||||
|
|
||||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||||
type CanonicalStreamConverter struct {
|
type CanonicalStreamConverter struct {
|
||||||
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
|
|||||||
event.Message.Model = c.modelOverride
|
event.Message.Model = c.modelOverride
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitSSEFrames(data []byte) ([][]byte, []byte) {
|
||||||
|
var frames [][]byte
|
||||||
|
for len(data) > 0 {
|
||||||
|
idx, sepLen := findSSEFrameSeparator(data)
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
end := idx + sepLen
|
||||||
|
frames = append(frames, append([]byte(nil), data[:end]...))
|
||||||
|
data = data[end:]
|
||||||
|
}
|
||||||
|
return frames, data
|
||||||
|
}
|
||||||
|
|
||||||
|
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||||
|
lf := bytes.Index(data, []byte("\n\n"))
|
||||||
|
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||||
|
switch {
|
||||||
|
case lf < 0 && crlf < 0:
|
||||||
|
return -1, 0
|
||||||
|
case lf < 0:
|
||||||
|
return crlf, 4
|
||||||
|
case crlf < 0:
|
||||||
|
return lf, 2
|
||||||
|
case crlf <= lf:
|
||||||
|
return crlf, 4
|
||||||
|
default:
|
||||||
|
return lf, 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||||
|
text := strings.TrimRight(string(frame), "\r\n")
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
var dataLines []string
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimRight(line, "\r")
|
||||||
|
if strings.HasPrefix(line, "data:") {
|
||||||
|
value := strings.TrimPrefix(line, "data:")
|
||||||
|
if strings.HasPrefix(value, " ") {
|
||||||
|
value = value[1:]
|
||||||
|
}
|
||||||
|
dataLines = append(dataLines, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(dataLines) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return strings.Join(dataLines, "\n"), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
|
||||||
|
lineEnding, separator := sseLineEnding(frame)
|
||||||
|
text := strings.TrimRight(string(frame), "\r\n")
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
out := make([]string, 0, len(lines)+1)
|
||||||
|
dataWritten := false
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimRight(line, "\r")
|
||||||
|
if strings.HasPrefix(line, "data:") {
|
||||||
|
if !dataWritten {
|
||||||
|
for _, dataLine := range strings.Split(data, "\n") {
|
||||||
|
out = append(out, "data: "+dataLine)
|
||||||
|
}
|
||||||
|
dataWritten = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, line)
|
||||||
|
}
|
||||||
|
if !dataWritten {
|
||||||
|
out = append(out, "data: "+data)
|
||||||
|
}
|
||||||
|
return []byte(strings.Join(out, lineEnding) + separator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sseLineEnding(frame []byte) (string, string) {
|
||||||
|
if bytes.Contains(frame, []byte("\r\n")) {
|
||||||
|
return "\r\n", "\r\n\r\n"
|
||||||
|
}
|
||||||
|
return "\n", "\n\n"
|
||||||
|
}
|
||||||
|
|||||||
135
backend/internal/database/database.go
Normal file
135
backend/internal/database/database.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
"nex/backend/migrations"
|
||||||
|
pkglogger "nex/backend/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMigration = errors.New("数据库迁移失败")
|
||||||
|
|
||||||
|
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||||
|
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||||
|
|
||||||
|
db, err := initDB(cfg, moduleLogger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", ErrMigration, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configurePool(db, cfg, moduleLogger)
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Close(db *gorm.DB) {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sqlDB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||||
|
gormLogger := pkglogger.NewGormLogger(zapLogger)
|
||||||
|
|
||||||
|
gormConfig := &gorm.Config{
|
||||||
|
Logger: gormLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cfg.Driver {
|
||||||
|
case "mysql":
|
||||||
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||||
|
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||||
|
if zapLogger != nil {
|
||||||
|
zapLogger.Info("连接 MySQL 数据库",
|
||||||
|
zap.String("host", cfg.Host),
|
||||||
|
zap.Int("port", cfg.Port),
|
||||||
|
zap.String("database", cfg.DBName))
|
||||||
|
}
|
||||||
|
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||||
|
default:
|
||||||
|
dbDir := filepath.Dir(cfg.Path)
|
||||||
|
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||||
|
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||||
|
}
|
||||||
|
if zapLogger != nil {
|
||||||
|
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
|
||||||
|
}
|
||||||
|
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dialect, fsys, err := migrations.ForDriver(driver)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if zapLogger != nil {
|
||||||
|
zapLogger.Info("执行数据库迁移",
|
||||||
|
zap.String("dialect", string(dialect)),
|
||||||
|
zap.String("driver", driver))
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := goose.NewProvider(dialect, sqlDB, fsys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("创建迁移提供者失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := provider.Up(context.Background()); err != nil {
|
||||||
|
return fmt.Errorf("执行迁移失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
|
||||||
|
if cfg.Driver == "sqlite" {
|
||||||
|
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||||
|
if zapLogger != nil {
|
||||||
|
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
|
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
|
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||||
|
|
||||||
|
if zapLogger != nil {
|
||||||
|
zapLogger.Info("数据库连接池配置",
|
||||||
|
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||||
|
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||||
|
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildDSN(cfg *config.DatabaseConfig) string {
|
||||||
|
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||||
|
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||||
|
}
|
||||||
165
backend/internal/database/database_test.go
Normal file
165
backend/internal/database/database_test.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
"nex/backend/migrations"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInit_SQLite(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
db, err := Init(cfg, zapLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, db)
|
||||||
|
defer Close(db)
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, sqlDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClose(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
db, err := Init(cfg, zapLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, db)
|
||||||
|
|
||||||
|
Close(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDSN(t *testing.T) {
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "mysql",
|
||||||
|
Host: "db.example.com",
|
||||||
|
Port: 3306,
|
||||||
|
User: "nexuser",
|
||||||
|
Password: "secretpass",
|
||||||
|
DBName: "nexdb",
|
||||||
|
}
|
||||||
|
|
||||||
|
dsn := BuildDSN(cfg)
|
||||||
|
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDSN_EmptyPassword(t *testing.T) {
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "mysql",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 3306,
|
||||||
|
User: "root",
|
||||||
|
DBName: "nex",
|
||||||
|
}
|
||||||
|
|
||||||
|
dsn := BuildDSN(cfg)
|
||||||
|
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInit_SQLite_AnyCWD(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
origDir, err := os.Getwd()
|
||||||
|
if err == nil {
|
||||||
|
defer func() {
|
||||||
|
if chdirErr := os.Chdir(origDir); chdirErr != nil {
|
||||||
|
t.Logf("无法恢复工作目录: %v", chdirErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
if chdirErr := os.Chdir(dir); chdirErr != nil {
|
||||||
|
t.Skipf("无法切换到临时目录: %v", chdirErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
db, err := Init(cfg, zapLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, db)
|
||||||
|
defer Close(db)
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, sqlDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForDriverDialect_SQLite(t *testing.T) {
|
||||||
|
require.NoError(t, testMigrateWithDriver(t, "sqlite"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForDriverDialect_MySQL(t *testing.T) {
|
||||||
|
dialect, fsys, err := migrations.ForDriver("mysql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "mysql", string(dialect))
|
||||||
|
entries, fsErr := fs.ReadDir(fsys, ".")
|
||||||
|
require.NoError(t, fsErr)
|
||||||
|
assert.NotEmpty(t, entries, "MySQL 迁移资源应至少包含一个文件")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForDriverDialect_Invalid(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: "postgres",
|
||||||
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
_, err := Init(cfg, zapLogger)
|
||||||
|
assert.Error(t, err, "非法 driver 应返回错误")
|
||||||
|
assert.Contains(t, err.Error(), "不支持的数据库驱动")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMigrateWithDriver(t *testing.T, driver string) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfg := &config.DatabaseConfig{
|
||||||
|
Driver: driver,
|
||||||
|
Path: filepath.Join(dir, "test.db"),
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
ConnMaxLifetime: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
zapLogger := zap.NewNop()
|
||||||
|
db, err := Init(cfg, zapLogger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
Close(db)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
71
backend/internal/database/embedded_migration_test.go
Normal file
71
backend/internal/database/embedded_migration_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/migrations"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEmbeddedMigrations_SQLiteResourcesPresent(t *testing.T) {
|
||||||
|
entries, err := fs.ReadDir(migrations.FS, "sqlite")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var sqlFiles []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if !entry.IsDir() {
|
||||||
|
sqlFiles = append(sqlFiles, entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, sqlFiles, "SQLite 迁移资源应至少包含一个 .sql 文件")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedMigrations_MySQLResourcesPresent(t *testing.T) {
|
||||||
|
entries, err := fs.ReadDir(migrations.FS, "mysql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var sqlFiles []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if !entry.IsDir() {
|
||||||
|
sqlFiles = append(sqlFiles, entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, sqlFiles, "MySQL 迁移资源应至少包含一个 .sql 文件")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedMigrations_SQLiteSQLParsable(t *testing.T) {
|
||||||
|
subFS, err := fs.Sub(migrations.FS, "sqlite")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries, err := fs.ReadDir(subFS, ".")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := fs.ReadFile(subFS, entry.Name())
|
||||||
|
require.NoError(t, err, "无法读取迁移文件: %s", entry.Name())
|
||||||
|
assert.NotEmpty(t, data, "迁移文件内容不应为空: %s", entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedMigrations_MySQLSQLParsable(t *testing.T) {
|
||||||
|
subFS, err := fs.Sub(migrations.FS, "mysql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries, err := fs.ReadDir(subFS, ".")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := fs.ReadFile(subFS, entry.Name())
|
||||||
|
require.NoError(t, err, "无法读取迁移文件: %s", entry.Name())
|
||||||
|
assert.NotEmpty(t, data, "迁移文件内容不应为空: %s", entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,4 +13,3 @@ type Provider struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/domain"
|
||||||
|
"nex/backend/tests/mocks"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
|
||||||
"nex/backend/tests/mocks"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||||
|
|||||||
@@ -9,23 +9,22 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/domain"
|
||||||
|
"nex/backend/tests/mocks"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
|
||||||
appErrors "nex/backend/pkg/errors"
|
appErrors "nex/backend/pkg/errors"
|
||||||
"nex/backend/tests/mocks"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
pkglogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Logging 日志中间件
|
|
||||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
|||||||
query := c.Request.URL.RawQuery
|
query := c.Request.URL.RawQuery
|
||||||
|
|
||||||
requestID, _ := c.Get(RequestIDKey)
|
requestID, _ := c.Get(RequestIDKey)
|
||||||
logger.Info("请求开始",
|
var requestIDStr string
|
||||||
zap.String("method", c.Request.Method),
|
if id, ok := requestID.(string); ok {
|
||||||
zap.String("path", path),
|
requestIDStr = id
|
||||||
zap.String("query", query),
|
}
|
||||||
zap.String("client_ip", c.ClientIP()),
|
logger.Debug("请求开始",
|
||||||
zap.Any("request_id", requestID),
|
pkglogger.Method(c.Request.Method),
|
||||||
|
pkglogger.Path(path),
|
||||||
|
pkglogger.Query(query),
|
||||||
|
pkglogger.ClientIP(c.ClientIP()),
|
||||||
|
pkglogger.RequestID(requestIDStr),
|
||||||
)
|
)
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -28,13 +33,13 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
|||||||
latency := time.Since(start)
|
latency := time.Since(start)
|
||||||
statusCode := c.Writer.Status()
|
statusCode := c.Writer.Status()
|
||||||
|
|
||||||
logger.Info("请求结束",
|
logger.Debug("请求结束",
|
||||||
zap.Int("status", statusCode),
|
pkglogger.StatusCode(statusCode),
|
||||||
zap.String("method", c.Request.Method),
|
pkglogger.Method(c.Request.Method),
|
||||||
zap.String("path", path),
|
pkglogger.Path(path),
|
||||||
zap.Duration("latency", latency),
|
pkglogger.Latency(latency),
|
||||||
zap.Int("body_size", c.Writer.Size()),
|
pkglogger.BodySize(c.Writer.Size()),
|
||||||
zap.Any("request_id", requestID),
|
pkglogger.RequestID(requestIDStr),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
|
|||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
|
||||||
|
core, logs := observer.New(zapcore.InfoLevel)
|
||||||
|
logger := zap.New(core)
|
||||||
|
|
||||||
|
w := serveLoggingRequest(logger)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
assert.Empty(t, logs.FilterMessage("请求开始").All())
|
||||||
|
assert.Empty(t, logs.FilterMessage("请求结束").All())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
|
||||||
|
core, logs := observer.New(zapcore.DebugLevel)
|
||||||
|
logger := zap.New(core)
|
||||||
|
|
||||||
|
w := serveLoggingRequest(logger)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
startLogs := logs.FilterMessage("请求开始").All()
|
||||||
|
endLogs := logs.FilterMessage("请求结束").All()
|
||||||
|
if assert.Len(t, startLogs, 1) {
|
||||||
|
fields := startLogs[0].ContextMap()
|
||||||
|
assert.Equal(t, "GET", fields["method"])
|
||||||
|
assert.Equal(t, "/test", fields["path"])
|
||||||
|
assert.Equal(t, "key=value", fields["query"])
|
||||||
|
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||||
|
assert.NotEmpty(t, fields["client_ip"])
|
||||||
|
}
|
||||||
|
if assert.Len(t, endLogs, 1) {
|
||||||
|
fields := endLogs[0].ContextMap()
|
||||||
|
assert.Equal(t, int64(200), fields["status"])
|
||||||
|
assert.Equal(t, "GET", fields["method"])
|
||||||
|
assert.Equal(t, "/test", fields["path"])
|
||||||
|
assert.Equal(t, int64(2), fields["body_size"])
|
||||||
|
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||||
|
assert.Contains(t, fields, "latency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestID())
|
||||||
|
r.Use(Logging(logger))
|
||||||
|
r.GET("/test", func(c *gin.Context) {
|
||||||
|
c.String(200, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/test?key=value", nil)
|
||||||
|
req.Header.Set("X-Request-ID", "existing-id-123")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
func TestRecovery_NoPanic(t *testing.T) {
|
func TestRecovery_NoPanic(t *testing.T) {
|
||||||
logger := zap.NewNop()
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"nex/backend/internal/domain"
|
||||||
|
"nex/backend/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
appErrors "nex/backend/pkg/errors"
|
appErrors "nex/backend/pkg/errors"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
|
||||||
"nex/backend/internal/service"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ModelHandler 模型管理处理器
|
// ModelHandler 模型管理处理器
|
||||||
@@ -58,13 +58,13 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.modelService.Create(model)
|
err := h.modelService.Create(model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == appErrors.ErrProviderNotFound {
|
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": "供应商不存在",
|
"error": "供应商不存在",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err == appErrors.ErrDuplicateModel {
|
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||||
c.JSON(http.StatusConflict, gin.H{
|
c.JSON(http.StatusConflict, gin.H{
|
||||||
"error": "同一供应商下模型名称已存在",
|
"error": "同一供应商下模型名称已存在",
|
||||||
"code": appErrors.ErrDuplicateModel.Code,
|
"code": appErrors.ErrDuplicateModel.Code,
|
||||||
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
|||||||
|
|
||||||
model, err := h.modelService.Get(id)
|
model, err := h.modelService.Get(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "模型未找到",
|
"error": "模型未找到",
|
||||||
})
|
})
|
||||||
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.modelService.Delete(id)
|
err := h.modelService.Delete(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "模型未找到",
|
"error": "模型未找到",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"nex/backend/internal/domain"
|
||||||
|
"nex/backend/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
appErrors "nex/backend/pkg/errors"
|
appErrors "nex/backend/pkg/errors"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
|
||||||
"nex/backend/internal/service"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProviderHandler 供应商管理处理器
|
// ProviderHandler 供应商管理处理器
|
||||||
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.providerService.Create(provider)
|
err := h.providerService.Create(provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == appErrors.ErrInvalidProviderID {
|
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": appErrors.ErrInvalidProviderID.Message,
|
"error": appErrors.ErrInvalidProviderID.Message,
|
||||||
"code": appErrors.ErrInvalidProviderID.Code,
|
"code": appErrors.ErrInvalidProviderID.Code,
|
||||||
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
|||||||
|
|
||||||
provider, err := h.providerService.Get(id)
|
provider, err := h.providerService.Get(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "供应商未找到",
|
"error": "供应商未找到",
|
||||||
})
|
})
|
||||||
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.providerService.Update(id, req)
|
err := h.providerService.Update(id, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "供应商未找到",
|
"error": "供应商未找到",
|
||||||
})
|
})
|
||||||
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.providerService.Delete(id)
|
err := h.providerService.Delete(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "供应商未找到",
|
"error": "供应商未找到",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,19 +3,23 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/canonical"
|
"nex/backend/internal/conversion/canonical"
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
"nex/backend/internal/service"
|
"nex/backend/internal/service"
|
||||||
|
appErrors "nex/backend/pkg/errors"
|
||||||
"nex/backend/pkg/modelid"
|
"nex/backend/pkg/modelid"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
pkglogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyHandler 统一代理处理器
|
// ProxyHandler 统一代理处理器
|
||||||
@@ -29,14 +33,14 @@ type ProxyHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyHandler 创建统一代理处理器
|
// NewProxyHandler 创建统一代理处理器
|
||||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
|
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||||||
return &ProxyHandler{
|
return &ProxyHandler{
|
||||||
engine: engine,
|
engine: engine,
|
||||||
client: client,
|
client: client,
|
||||||
routingService: routingService,
|
routingService: routingService,
|
||||||
providerService: providerService,
|
providerService: providerService,
|
||||||
statsService: statsService,
|
statsService: statsService,
|
||||||
logger: zap.L(),
|
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||||||
clientProtocol := c.Param("protocol")
|
clientProtocol := c.Param("protocol")
|
||||||
if clientProtocol == "" {
|
if clientProtocol == "" {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
|
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
path = "/" + path
|
path = "/" + path
|
||||||
}
|
}
|
||||||
nativePath := path
|
nativePath := path
|
||||||
|
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
|
||||||
|
|
||||||
// 获取 client adapter
|
// 获取 client adapter
|
||||||
registry := h.engine.GetRegistry()
|
registry := h.engine.GetRegistry()
|
||||||
clientAdapter, err := registry.Get(clientProtocol)
|
clientAdapter, err := registry.Get(clientProtocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
|
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||||
@@ -87,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
// 读取请求体
|
// 读取请求体
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析统一模型 ID(使用 adapter.ExtractModelName)
|
|
||||||
var providerID, modelName string
|
|
||||||
if len(body) > 0 {
|
|
||||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
|
||||||
if err == nil && unifiedID != "" {
|
|
||||||
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
|
|
||||||
if err == nil {
|
|
||||||
providerID = pid
|
|
||||||
modelName = mn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建输入 HTTPRequestSpec
|
// 构建输入 HTTPRequestSpec
|
||||||
inSpec := conversion.HTTPRequestSpec{
|
inSpec := conversion.HTTPRequestSpec{
|
||||||
URL: nativePath,
|
URL: requestPath,
|
||||||
Method: c.Request.Method,
|
Method: c.Request.Method,
|
||||||
Headers: extractHeaders(c),
|
Headers: extractHeaders(c),
|
||||||
Body: body,
|
Body: body,
|
||||||
}
|
}
|
||||||
|
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||||
|
|
||||||
|
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
|
||||||
|
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
|
||||||
|
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||||
|
if err != nil {
|
||||||
|
if isInvalidJSONError(err) {
|
||||||
|
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||||
|
if err != nil {
|
||||||
|
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
|
||||||
|
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if providerID == "" || modelName == "" {
|
||||||
|
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 路由
|
// 路由
|
||||||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// GET 请求或无法提取 model 时,直接转发到上游
|
h.writeRouteError(c, err)
|
||||||
if len(body) == 0 || modelName == "" {
|
|
||||||
h.forwardPassthrough(c, inSpec, clientProtocol)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.writeError(c, err, clientProtocol)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,9 +155,6 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||||
)
|
)
|
||||||
|
|
||||||
// 判断是否流式
|
|
||||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
|
||||||
|
|
||||||
// 计算统一模型 ID(用于响应覆写)
|
// 计算统一模型 ID(用于响应覆写)
|
||||||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||||
|
|
||||||
@@ -153,12 +165,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||||||
|
switch ifaceType {
|
||||||
|
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isInvalidJSONError(err error) bool {
|
||||||
|
var syntaxErr *json.SyntaxError
|
||||||
|
var typeErr *json.UnmarshalTypeError
|
||||||
|
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendRawQuery(path, rawQuery string) string {
|
||||||
|
if rawQuery == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return path + "?" + rawQuery
|
||||||
|
}
|
||||||
|
|
||||||
// handleNonStream 处理非流式请求
|
// handleNonStream 处理非流式请求
|
||||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||||
// 转换请求
|
// 转换请求
|
||||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
h.logger.Error("转换请求失败", zap.Error(err))
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -166,31 +200,27 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
|||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
h.logger.Error("发送请求失败", zap.Error(err))
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeUpstreamUnavailable(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
h.writeUpstreamResponse(c, *resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
h.logger.Error("转换响应失败", zap.Error(err))
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
h.writeConvertedResponse(c, *convertedResp)
|
||||||
for k, v := range convertedResp.Headers {
|
|
||||||
c.Header(k, v)
|
|
||||||
}
|
|
||||||
if c.GetHeader("Content-Type") == "" {
|
|
||||||
c.Header("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
// 发送流式请求
|
||||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeUpstreamUnavailable(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: streamResp.StatusCode,
|
||||||
|
Headers: streamResp.Headers,
|
||||||
|
Body: streamResp.Body,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送流式请求
|
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
return
|
return
|
||||||
@@ -222,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
|||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
|
|
||||||
writer := bufio.NewWriter(c.Writer)
|
writer := bufio.NewWriter(c.Writer)
|
||||||
|
flushed := false
|
||||||
|
|
||||||
for event := range eventChan {
|
for event := range streamResp.Events {
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if event.Done {
|
if event.Done {
|
||||||
// flush 转换器
|
// flush 转换器
|
||||||
chunks := streamConverter.Flush()
|
chunks := streamConverter.Flush()
|
||||||
for _, chunk := range chunks {
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
writer.Write(chunk)
|
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||||
writer.Flush()
|
|
||||||
}
|
}
|
||||||
|
flushed = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := streamConverter.ProcessChunk(event.Data)
|
chunks := streamConverter.ProcessChunk(event.Data)
|
||||||
for _, chunk := range chunks {
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
writer.Write(chunk)
|
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||||
writer.Flush()
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !flushed {
|
||||||
|
chunks := streamConverter.Flush()
|
||||||
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
|
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
if _, err := writer.Write(chunk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writer.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// isStreamRequest 判断是否流式请求
|
// isStreamRequest 判断是否流式请求
|
||||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if ifaceType != conversion.InterfaceTypeChat {
|
if ifaceType != conversion.InterfaceTypeChat {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -271,8 +333,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
|||||||
// 从数据库查询所有启用的模型
|
// 从数据库查询所有启用的模型
|
||||||
models, err := h.providerService.ListEnabledModels()
|
models, err := h.providerService.ListEnabledModels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
|
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,8 +355,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
|||||||
// 使用 adapter 编码返回
|
// 使用 adapter 编码返回
|
||||||
body, err := adapter.EncodeModelsResponse(modelList)
|
body, err := adapter.EncodeModelsResponse(modelList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
|
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
|||||||
// 解析统一模型 ID
|
// 解析统一模型 ID
|
||||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
|
||||||
"error": "无效的统一模型 ID 格式",
|
|
||||||
"code": "INVALID_MODEL_ID",
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从数据库查询模型
|
// 从数据库查询模型
|
||||||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
|
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,42 +390,104 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
|||||||
// 使用 adapter 编码返回
|
// 使用 adapter 编码返回
|
||||||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
|
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Data(http.StatusOK, "application/json", body)
|
c.Data(http.StatusOK, "application/json", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeConversionError 写入转换错误
|
// writeConversionError 写入网关层转换错误
|
||||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
var convErr *conversion.ConversionError
|
||||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
if errors.As(err, &convErr) {
|
||||||
c.Data(statusCode, "application/json", body)
|
statusCode, code, message := mapConversionError(convErr)
|
||||||
|
h.writeProxyError(c, statusCode, code, message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeError 写入路由错误
|
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
|
||||||
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
|
switch err.Code {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
case conversion.ErrorCodeJSONParseError:
|
||||||
|
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
|
||||||
|
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
|
||||||
|
}
|
||||||
|
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||||
|
case conversion.ErrorCodeInvalidInput,
|
||||||
|
conversion.ErrorCodeMissingRequiredField,
|
||||||
|
conversion.ErrorCodeProtocolConstraint:
|
||||||
|
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
|
||||||
|
case conversion.ErrorCodeInterfaceNotSupported:
|
||||||
|
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
|
||||||
|
case conversion.ErrorCodeUnsupportedMultimodal:
|
||||||
|
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
|
||||||
|
default:
|
||||||
|
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
|
||||||
|
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||||
|
switch appErr.Code {
|
||||||
|
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
|
||||||
|
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
|
||||||
|
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
|
||||||
|
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
|
||||||
|
default:
|
||||||
|
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
|
||||||
|
h.logger.Error("上游不可达", zap.Error(err))
|
||||||
|
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||||||
|
c.JSON(status, gin.H{
|
||||||
|
"error": message,
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||||
|
for k, v := range resp.Headers {
|
||||||
|
c.Header(k, v)
|
||||||
|
}
|
||||||
|
contentType := headerValue(resp.Headers, "Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/json"
|
||||||
|
}
|
||||||
|
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||||
|
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||||||
|
c.Header(k, v)
|
||||||
|
}
|
||||||
|
contentType := headerValue(resp.Headers, "Content-Type")
|
||||||
|
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
|
||||||
registry := h.engine.GetRegistry()
|
registry := h.engine.GetRegistry()
|
||||||
adapter, err := registry.Get(clientProtocol)
|
adapter, err := registry.Get(clientProtocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
providers, err := h.providerService.List()
|
providers, err := h.providerService.List()
|
||||||
if err != nil || len(providers) == 0 {
|
if err != nil || len(providers) == 0 {
|
||||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
|||||||
providerProtocol = "openai"
|
providerProtocol = "openai"
|
||||||
}
|
}
|
||||||
|
|
||||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
|
||||||
|
|
||||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||||
|
|
||||||
var outSpec *conversion.HTTPRequestSpec
|
var outSpec *conversion.HTTPRequestSpec
|
||||||
if clientProtocol == providerProtocol {
|
if clientProtocol == providerProtocol {
|
||||||
upstreamURL := p.BaseURL + inSpec.URL
|
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
|
||||||
|
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
|
||||||
headers := adapter.BuildHeaders(targetProvider)
|
headers := adapter.BuildHeaders(targetProvider)
|
||||||
if _, ok := headers["Content-Type"]; !ok {
|
if _, ok := headers["Content-Type"]; !ok {
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
}
|
}
|
||||||
outSpec = &conversion.HTTPRequestSpec{
|
outSpec = &conversion.HTTPRequestSpec{
|
||||||
URL: upstreamURL,
|
URL: joinBaseURL(p.BaseURL, upstreamPath),
|
||||||
Method: inSpec.Method,
|
Method: inSpec.Method,
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
Body: inSpec.Body,
|
Body: inSpec.Body,
|
||||||
@@ -401,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isStream {
|
||||||
|
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeUpstreamUnavailable(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
h.writeUpstreamResponse(c, *resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -413,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range convertedResp.Headers {
|
h.writeConvertedResponse(c, *convertedResp)
|
||||||
c.Header(k, v)
|
|
||||||
}
|
}
|
||||||
if c.GetHeader("Content-Type") == "" {
|
|
||||||
c.Header("Content-Type", "application/json")
|
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
|
||||||
|
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
|
||||||
|
if err != nil {
|
||||||
|
h.writeUpstreamUnavailable(c, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: streamResp.StatusCode,
|
||||||
|
Headers: streamResp.Headers,
|
||||||
|
Body: streamResp.Body,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
|
||||||
|
if err != nil {
|
||||||
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
|
||||||
|
writer := bufio.NewWriter(c.Writer)
|
||||||
|
flushed := false
|
||||||
|
for event := range streamResp.Events {
|
||||||
|
if event.Error != nil {
|
||||||
|
h.logger.Error("透传流读取错误", zap.Error(event.Error))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if event.Done {
|
||||||
|
chunks := streamConverter.Flush()
|
||||||
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
|
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
flushed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
chunks := streamConverter.ProcessChunk(event.Data)
|
||||||
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
|
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !flushed {
|
||||||
|
chunks := streamConverter.Flush()
|
||||||
|
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||||
|
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripRawQuery(path string) string {
|
||||||
|
pathOnly, _, _ := strings.Cut(path, "?")
|
||||||
|
return pathOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawQueryFromPath(path string) string {
|
||||||
|
_, rawQuery, found := strings.Cut(path, "?")
|
||||||
|
if !found {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return rawQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinBaseURL(baseURL, path string) string {
|
||||||
|
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerValue(headers map[string]string, key string) string {
|
||||||
|
for k, v := range headers {
|
||||||
|
if strings.EqualFold(k, key) {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterHopByHopHeaders(headers map[string]string) map[string]string {
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hopByHop := map[string]struct{}{
|
||||||
|
"connection": {},
|
||||||
|
"transfer-encoding": {},
|
||||||
|
"keep-alive": {},
|
||||||
|
"proxy-authenticate": {},
|
||||||
|
"proxy-authorization": {},
|
||||||
|
"te": {},
|
||||||
|
"trailer": {},
|
||||||
|
"upgrade": {},
|
||||||
|
}
|
||||||
|
filtered := make(map[string]string, len(headers))
|
||||||
|
for k, v := range headers {
|
||||||
|
if _, skip := hopByHop[strings.ToLower(k)]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered[k] = v
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractHeaders 从 Gin context 提取请求头
|
// extractHeaders 从 Gin context 提取请求头
|
||||||
|
|||||||
@@ -5,33 +5,34 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"go.uber.org/mock/gomock"
|
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/anthropic"
|
"nex/backend/internal/conversion/anthropic"
|
||||||
"nex/backend/internal/conversion/openai"
|
"nex/backend/internal/conversion/openai"
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
appErrors "nex/backend/pkg/errors"
|
|
||||||
"nex/backend/tests/mocks"
|
"nex/backend/tests/mocks"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
appErrors "nex/backend/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
|
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
engine := conversion.NewConversionEngine(registry, nil)
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||||
return engine
|
return engine
|
||||||
@@ -44,6 +45,7 @@ func newTestProxyHandler(engine *conversion.ConversionEngine, client *mocks.Mock
|
|||||||
routingSvc,
|
routingSvc,
|
||||||
providerSvc,
|
providerSvc,
|
||||||
statsSvc,
|
statsSvc,
|
||||||
|
zap.NewNop(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -91,8 +93,8 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -108,20 +110,20 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound)
|
routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
providerSvc.EXPECT().List().Return(nil, nil)
|
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||||
@@ -130,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -144,11 +146,12 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 502, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||||
@@ -157,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -171,11 +174,12 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 502, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||||
@@ -184,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
ch := make(chan provider.StreamEvent, 10)
|
ch := make(chan provider.StreamEvent, 10)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
@@ -198,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
|||||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||||
ch <- provider.StreamEvent{Done: true}
|
ch <- provider.StreamEvent{Done: true}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||||
})
|
})
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
@@ -207,13 +211,14 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
||||||
assert.Contains(t, w.Body.String(), "Hello")
|
assert.Contains(t, w.Body.String(), "Hello")
|
||||||
|
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||||
@@ -222,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
return nil, context.DeadlineExceeded
|
return nil, context.DeadlineExceeded
|
||||||
})
|
})
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
@@ -236,11 +241,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 502, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
||||||
@@ -260,8 +266,8 @@ func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -281,11 +287,11 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
|
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 400, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
||||||
@@ -303,8 +309,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -328,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -346,8 +352,8 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -370,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
|
|||||||
|
|
||||||
h.writeConversionError(c, context.DeadlineExceeded, "openai")
|
h.writeConversionError(c, context.DeadlineExceeded, "openai")
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
||||||
@@ -389,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
|||||||
|
|
||||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
|
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
|
||||||
h.writeConversionError(c, convErr, "openai")
|
h.writeConversionError(c, convErr, "openai")
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 400, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
t.Run("request json parse error", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||||
|
|
||||||
|
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("response json parse error", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||||
|
|
||||||
|
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
||||||
@@ -409,8 +449,8 @@ func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -422,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
ch := make(chan provider.StreamEvent, 10)
|
ch := make(chan provider.StreamEvent, 10)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
|
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
|
||||||
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
|
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||||
})
|
})
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
@@ -443,8 +483,8 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -459,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
ch := make(chan provider.StreamEvent, 10)
|
ch := make(chan provider.StreamEvent, 10)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
@@ -472,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
|||||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||||
ch <- provider.StreamEvent{Done: true}
|
ch <- provider.StreamEvent{Done: true}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||||
})
|
})
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
@@ -481,8 +521,8 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -499,12 +539,12 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
engine := conversion.NewConversionEngine(registry, nil)
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
err := registry.Register(openai.NewAdapter())
|
err := registry.Register(openai.NewAdapter())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -515,8 +555,8 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
@@ -527,11 +567,11 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
engine := conversion.NewConversionEngine(registry, nil)
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||||
|
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -542,8 +582,8 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
@@ -554,12 +594,12 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
engine := conversion.NewConversionEngine(registry, nil)
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||||
|
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -577,8 +617,8 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
@@ -590,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
|||||||
|
|
||||||
engine := setupProxyEngine(t)
|
engine := setupProxyEngine(t)
|
||||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
@@ -609,8 +649,8 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -623,7 +663,7 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
registry := conversion.NewMemoryRegistry()
|
registry := conversion.NewMemoryRegistry()
|
||||||
engine := conversion.NewConversionEngine(registry, nil)
|
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||||
|
|
||||||
anthropicAdapter := anthropic.NewAdapter()
|
anthropicAdapter := anthropic.NewAdapter()
|
||||||
@@ -641,8 +681,8 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -665,8 +705,8 @@ func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -689,10 +729,10 @@ func TestIsStreamRequest_EdgeCases(t *testing.T) {
|
|||||||
path string
|
path string
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/chat/completions", true},
|
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
|
||||||
{"stream with spaces", `{"stream" : true}`, "/chat/completions", true},
|
{"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
|
||||||
{"stream embedded in string value", `{"model":"stream:true"}`, "/chat/completions", false},
|
{"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
|
||||||
{"empty body", "", "/chat/completions", false},
|
{"empty body", "", "/v1/chat/completions", false},
|
||||||
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
|
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -719,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
|
|||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||||
|
|
||||||
h.writeError(c, fmt.Errorf("model not found"), "openai")
|
h.writeRouteError(c, fmt.Errorf("model not found"))
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
||||||
@@ -740,8 +781,8 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -764,35 +805,35 @@ func TestIsStreamRequest(t *testing.T) {
|
|||||||
name: "stream true",
|
name: "stream true",
|
||||||
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
||||||
clientProtocol: "openai",
|
clientProtocol: "openai",
|
||||||
nativePath: "/chat/completions",
|
nativePath: "/v1/chat/completions",
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "stream false",
|
name: "stream false",
|
||||||
body: []byte(`{"model": "gpt-4", "stream": false}`),
|
body: []byte(`{"model": "gpt-4", "stream": false}`),
|
||||||
clientProtocol: "openai",
|
clientProtocol: "openai",
|
||||||
nativePath: "/chat/completions",
|
nativePath: "/v1/chat/completions",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no stream field",
|
name: "no stream field",
|
||||||
body: []byte(`{"model": "gpt-4"}`),
|
body: []byte(`{"model": "gpt-4"}`),
|
||||||
clientProtocol: "openai",
|
clientProtocol: "openai",
|
||||||
nativePath: "/chat/completions",
|
nativePath: "/v1/chat/completions",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid json",
|
name: "invalid json",
|
||||||
body: []byte(`{invalid}`),
|
body: []byte(`{invalid}`),
|
||||||
clientProtocol: "openai",
|
clientProtocol: "openai",
|
||||||
nativePath: "/chat/completions",
|
nativePath: "/v1/chat/completions",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "not chat endpoint",
|
name: "not chat endpoint",
|
||||||
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
||||||
clientProtocol: "openai",
|
clientProtocol: "openai",
|
||||||
nativePath: "/models",
|
nativePath: "/v1/models",
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -830,8 +871,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -842,7 +883,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, data, 2)
|
assert.Len(t, data, 2)
|
||||||
|
|
||||||
first := data[0].(map[string]interface{})
|
first, ok2 := data[0].(map[string]interface{})
|
||||||
|
require.True(t, ok2)
|
||||||
assert.Equal(t, "openai/gpt-4", first["id"])
|
assert.Equal(t, "openai/gpt-4", first["id"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -860,8 +902,8 @@ func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/openai/gpt-4"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models/openai/gpt-4", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -894,8 +936,8 @@ func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testi
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/"}}
|
||||||
c.Request = httptest.NewRequest("GET", "/openai/models/", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -916,7 +958,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
|||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
var req map[string]interface{}
|
var req map[string]interface{}
|
||||||
json.Unmarshal(spec.Body, &req)
|
require.NoError(t, json.Unmarshal(spec.Body, &req))
|
||||||
assert.Equal(t, "gpt-4", req["model"])
|
assert.Equal(t, "gpt-4", req["model"])
|
||||||
|
|
||||||
return &conversion.HTTPResponseSpec{
|
return &conversion.HTTPResponseSpec{
|
||||||
@@ -932,8 +974,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -970,8 +1012,8 @@ func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -992,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
|
|||||||
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
||||||
}, nil)
|
}, nil)
|
||||||
client := mocks.NewMockProviderClient(ctrl)
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
ch := make(chan provider.StreamEvent, 10)
|
ch := make(chan provider.StreamEvent, 10)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
@@ -1010,7 +1052,7 @@ data: {"type":"message_stop"}
|
|||||||
`)}
|
`)}
|
||||||
ch <- provider.StreamEvent{Done: true}
|
ch <- provider.StreamEvent{Done: true}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||||
})
|
})
|
||||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
@@ -1019,8 +1061,8 @@ data: {"type":"message_stop"}
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -1057,8 +1099,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
@@ -1088,8 +1130,8 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
|||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
@@ -1098,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
|||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
assert.Contains(t, resp, "error")
|
assert.Contains(t, resp, "error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
path string
|
||||||
|
requestPath string
|
||||||
|
baseURL string
|
||||||
|
expectedURL string
|
||||||
|
body string
|
||||||
|
responseBody string
|
||||||
|
responseModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai path keeps v1 after gateway prefix",
|
||||||
|
protocol: "openai",
|
||||||
|
path: "/v1/chat/completions",
|
||||||
|
requestPath: "/openai/v1/chat/completions",
|
||||||
|
baseURL: "https://api.test.com/v1",
|
||||||
|
expectedURL: "https://api.test.com/v1/chat/completions",
|
||||||
|
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
|
||||||
|
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
|
||||||
|
responseModel: "p1/gpt-4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anthropic path keeps v1 after gateway prefix",
|
||||||
|
protocol: "anthropic",
|
||||||
|
path: "/v1/messages",
|
||||||
|
requestPath: "/anthropic/v1/messages",
|
||||||
|
baseURL: "https://api.anthropic.test",
|
||||||
|
expectedURL: "https://api.anthropic.test/v1/messages",
|
||||||
|
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
|
||||||
|
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
|
||||||
|
responseModel: "p1/gpt-4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
assert.Equal(t, tt.expectedURL, spec.URL)
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(tt.responseBody),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
|
||||||
|
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), tt.responseModel)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Upstream-Error": "rate-limit",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
},
|
||||||
|
Body: []byte(`{"error":{"message":"rate limited"}}`),
|
||||||
|
}, nil)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
|
||||||
|
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
|
||||||
|
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
|
||||||
|
Body: []byte(`{"error":"upstream down"}`),
|
||||||
|
}, nil)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
|
||||||
|
assert.Empty(t, w.Header().Get("Connection"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterHopByHopHeaders(t *testing.T) {
|
||||||
|
filtered := filterHopByHopHeaders(map[string]string{
|
||||||
|
"Connection": "close",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
"Keep-Alive": "timeout=5",
|
||||||
|
"Proxy-Authenticate": "Basic",
|
||||||
|
"Proxy-Authorization": "Basic token",
|
||||||
|
"TE": "trailers",
|
||||||
|
"Trailer": "Expires",
|
||||||
|
"Upgrade": "websocket",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Request-ID": "req-1",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Request-ID": "req-1",
|
||||||
|
}, filtered)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||||
|
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
|
||||||
|
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"ok":true}`),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
assert.Contains(t, string(spec.Body), "image_url")
|
||||||
|
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||||
|
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||||
|
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||||
|
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||||
|
}, nil)
|
||||||
|
client := mocks.NewMockProviderClient(ctrl)
|
||||||
|
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||||
|
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||||
|
ch := make(chan provider.StreamEvent, 3)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
|
||||||
|
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||||
|
ch <- provider.StreamEvent{Done: true}
|
||||||
|
}()
|
||||||
|
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||||
|
})
|
||||||
|
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
|
||||||
|
}
|
||||||
|
|||||||
223
backend/internal/handler/settings_handler.go
Normal file
223
backend/internal/handler/settings_handler.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
appErrors "nex/backend/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SettingsHandler struct {
|
||||||
|
runtimeCfg *config.Config
|
||||||
|
mode string
|
||||||
|
editable bool
|
||||||
|
configPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSettingsHandler(runtimeCfg *config.Config, mode string, editable bool, configPath string) *SettingsHandler {
|
||||||
|
return &SettingsHandler{
|
||||||
|
runtimeCfg: runtimeCfg,
|
||||||
|
mode: mode,
|
||||||
|
editable: editable,
|
||||||
|
configPath: configPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverConfigDTO struct {
|
||||||
|
Port int `json:"port"`
|
||||||
|
ReadTimeout string `json:"read_timeout"`
|
||||||
|
WriteTimeout string `json:"write_timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type databaseConfigDTO struct {
|
||||||
|
Driver string `json:"driver"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
User string `json:"user"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
DBName string `json:"dbname"`
|
||||||
|
MaxIdleConns int `json:"max_idle_conns"`
|
||||||
|
MaxOpenConns int `json:"max_open_conns"`
|
||||||
|
ConnMaxLifetime string `json:"conn_max_lifetime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type logConfigDTO struct {
|
||||||
|
Level string `json:"level"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
MaxSize int `json:"max_size"`
|
||||||
|
MaxBackups int `json:"max_backups"`
|
||||||
|
MaxAge int `json:"max_age"`
|
||||||
|
Compress bool `json:"compress"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type startupSettingsDTO struct {
|
||||||
|
Server serverConfigDTO `json:"server"`
|
||||||
|
Database databaseConfigDTO `json:"database"`
|
||||||
|
Log logConfigDTO `json:"log"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type startupSettingsResponse struct {
|
||||||
|
Mode string `json:"mode"`
|
||||||
|
Editable bool `json:"editable"`
|
||||||
|
ConfigPath string `json:"config_path"`
|
||||||
|
RestartRequired bool `json:"restart_required"`
|
||||||
|
Config startupSettingsDTO `json:"config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func configToDTO(cfg *config.Config) startupSettingsDTO {
|
||||||
|
return startupSettingsDTO{
|
||||||
|
Server: serverConfigDTO{
|
||||||
|
Port: cfg.Server.Port,
|
||||||
|
ReadTimeout: cfg.Server.ReadTimeout.String(),
|
||||||
|
WriteTimeout: cfg.Server.WriteTimeout.String(),
|
||||||
|
},
|
||||||
|
Database: databaseConfigDTO{
|
||||||
|
Driver: cfg.Database.Driver,
|
||||||
|
Path: cfg.Database.Path,
|
||||||
|
Host: cfg.Database.Host,
|
||||||
|
Port: cfg.Database.Port,
|
||||||
|
User: cfg.Database.User,
|
||||||
|
Password: cfg.Database.Password,
|
||||||
|
DBName: cfg.Database.DBName,
|
||||||
|
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||||
|
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||||
|
ConnMaxLifetime: cfg.Database.ConnMaxLifetime.String(),
|
||||||
|
},
|
||||||
|
Log: logConfigDTO{
|
||||||
|
Level: cfg.Log.Level,
|
||||||
|
Path: cfg.Log.Path,
|
||||||
|
MaxSize: cfg.Log.MaxSize,
|
||||||
|
MaxBackups: cfg.Log.MaxBackups,
|
||||||
|
MaxAge: cfg.Log.MaxAge,
|
||||||
|
Compress: cfg.Log.Compress,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dtoToConfig(dto startupSettingsDTO) (*config.Config, error) {
|
||||||
|
readTimeout, err := time.ParseDuration(dto.Server.ReadTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, appErrors.WithMessage(appErrors.ErrInvalidRequest, "read_timeout 格式错误,例如 30s")
|
||||||
|
}
|
||||||
|
writeTimeout, err := time.ParseDuration(dto.Server.WriteTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, appErrors.WithMessage(appErrors.ErrInvalidRequest, "write_timeout 格式错误,例如 30s")
|
||||||
|
}
|
||||||
|
connMaxLifetime, err := time.ParseDuration(dto.Database.ConnMaxLifetime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, appErrors.WithMessage(appErrors.ErrInvalidRequest, "conn_max_lifetime 格式错误,例如 1h")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Port: dto.Server.Port,
|
||||||
|
ReadTimeout: readTimeout,
|
||||||
|
WriteTimeout: writeTimeout,
|
||||||
|
},
|
||||||
|
Database: config.DatabaseConfig{
|
||||||
|
Driver: dto.Database.Driver,
|
||||||
|
Path: dto.Database.Path,
|
||||||
|
Host: dto.Database.Host,
|
||||||
|
Port: dto.Database.Port,
|
||||||
|
User: dto.Database.User,
|
||||||
|
Password: dto.Database.Password,
|
||||||
|
DBName: dto.Database.DBName,
|
||||||
|
MaxIdleConns: dto.Database.MaxIdleConns,
|
||||||
|
MaxOpenConns: dto.Database.MaxOpenConns,
|
||||||
|
ConnMaxLifetime: connMaxLifetime,
|
||||||
|
},
|
||||||
|
Log: config.LogConfig{
|
||||||
|
Level: dto.Log.Level,
|
||||||
|
Path: dto.Log.Path,
|
||||||
|
MaxSize: dto.Log.MaxSize,
|
||||||
|
MaxBackups: dto.Log.MaxBackups,
|
||||||
|
MaxAge: dto.Log.MaxAge,
|
||||||
|
Compress: dto.Log.Compress,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SettingsHandler) GetStartupSettings(c *gin.Context) {
|
||||||
|
var cfg *config.Config
|
||||||
|
var configPath string
|
||||||
|
|
||||||
|
if h.mode == "desktop" {
|
||||||
|
desktopCfg, err := config.LoadDesktopConfigAtPath(h.configPath)
|
||||||
|
if err != nil {
|
||||||
|
writeError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg = desktopCfg
|
||||||
|
configPath = h.configPath
|
||||||
|
} else {
|
||||||
|
cfg = h.runtimeCfg
|
||||||
|
configPath = h.configPath
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, startupSettingsResponse{
|
||||||
|
Mode: h.mode,
|
||||||
|
Editable: h.editable,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
RestartRequired: h.editable,
|
||||||
|
Config: configToDTO(cfg),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SettingsHandler) SaveStartupSettings(c *gin.Context) {
|
||||||
|
if !h.editable {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"error": "server 模式下不允许保存启动参数",
|
||||||
|
"code": "forbidden",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Config startupSettingsDTO `json:"config"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": "无效的请求格式",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := dtoToConfig(req.Config)
|
||||||
|
if err != nil {
|
||||||
|
writeError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
writeError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.SaveConfigToPath(cfg, h.configPath); err != nil {
|
||||||
|
if errors.Is(err, appErrors.ErrInvalidRequest) {
|
||||||
|
writeError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeError(c, appErrors.Wrap(appErrors.ErrInternal, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
savedCfg, err := config.LoadDesktopConfigAtPath(h.configPath)
|
||||||
|
if err != nil {
|
||||||
|
writeError(c, appErrors.Wrap(appErrors.ErrInternal, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, startupSettingsResponse{
|
||||||
|
Mode: h.mode,
|
||||||
|
Editable: h.editable,
|
||||||
|
ConfigPath: h.configPath,
|
||||||
|
RestartRequired: true,
|
||||||
|
Config: configToDTO(savedCfg),
|
||||||
|
})
|
||||||
|
}
|
||||||
510
backend/internal/handler/settings_handler_test.go
Normal file
510
backend/internal/handler/settings_handler_test.go
Normal file
@@ -0,0 +1,510 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestConfig(t *testing.T) (*config.Config, string) {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.Database.Path = filepath.Join(dir, "test.db")
|
||||||
|
cfg.Log.Path = filepath.Join(dir, "log")
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(configPath, data, 0o600))
|
||||||
|
|
||||||
|
return cfg, configPath
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_GetStartupSettings_Desktop(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/settings/startup", nil)
|
||||||
|
|
||||||
|
h.GetStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "desktop", resp.Mode)
|
||||||
|
assert.True(t, resp.Editable)
|
||||||
|
assert.True(t, resp.RestartRequired)
|
||||||
|
assert.Equal(t, configPath, resp.ConfigPath)
|
||||||
|
assert.Equal(t, cfg.Server.Port, resp.Config.Server.Port)
|
||||||
|
assert.Equal(t, "30s", resp.Config.Server.ReadTimeout)
|
||||||
|
assert.Equal(t, cfg.Database.Driver, resp.Config.Database.Driver)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_GetStartupSettings_Server(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "server", false, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/settings/startup", nil)
|
||||||
|
|
||||||
|
h.GetStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "server", resp.Mode)
|
||||||
|
assert.False(t, resp.Editable)
|
||||||
|
assert.False(t, resp.RestartRequired)
|
||||||
|
assert.Equal(t, configPath, resp.ConfigPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Desktop(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
newPort := 9999
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": newPort,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": filepath.Join(t.TempDir(), "new.db"),
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": filepath.Join(t.TempDir(), "log"),
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, newPort, resp.Config.Server.Port)
|
||||||
|
assert.True(t, resp.Editable)
|
||||||
|
assert.True(t, resp.RestartRequired)
|
||||||
|
|
||||||
|
savedCfg, err := config.LoadDesktopConfigAtPath(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, newPort, savedCfg.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Server_Forbidden(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "server", false, configPath)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 9999,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": "/tmp/test.db",
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": "/tmp/log",
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 403, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "不允许保存")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Desktop_InvalidConfig(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
originalData, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 0,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": "/tmp/test.db",
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": "/tmp/log",
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, w.Code)
|
||||||
|
|
||||||
|
currentData, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, originalData, currentData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Desktop_InvalidDuration(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 9826,
|
||||||
|
"read_timeout": "not-a-duration",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": "/tmp/test.db",
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": "/tmp/log",
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "read_timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Desktop_CreatesConfigFile(t *testing.T) {
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "nex", "config.yaml")
|
||||||
|
|
||||||
|
_, err := os.Stat(configPath)
|
||||||
|
assert.True(t, os.IsNotExist(err))
|
||||||
|
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 9826,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": filepath.Join(dir, "test.db"),
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": filepath.Join(dir, "log"),
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
_, err = os.Stat(configPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_Desktop_PasswordIncluded(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 9826,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "30s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "mysql",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 3306,
|
||||||
|
"user": "root",
|
||||||
|
"password": "secret123",
|
||||||
|
"dbname": "nex",
|
||||||
|
"path": "",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": filepath.Join(t.TempDir(), "log"),
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "secret123", resp.Config.Database.Password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_GetStartupSettings_DesktopReadsConfigFile(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
|
||||||
|
savedCfg := config.DefaultConfig()
|
||||||
|
savedCfg.Server.Port = 5555
|
||||||
|
data, err := yaml.Marshal(savedCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(configPath, data, 0o600))
|
||||||
|
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/settings/startup", nil)
|
||||||
|
|
||||||
|
h.GetStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, 5555, resp.Config.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_GetStartupSettings_ServerReturnsRuntime(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
cfg.Server.Port = 7777
|
||||||
|
|
||||||
|
h := NewSettingsHandler(cfg, "server", false, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/settings/startup", nil)
|
||||||
|
|
||||||
|
h.GetStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, 7777, resp.Config.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_InvalidJSON(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader([]byte("{invalid")))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_GetStartupSettings_DurationNormalization(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
|
||||||
|
yamlContent := `
|
||||||
|
server:
|
||||||
|
port: 9826
|
||||||
|
read_timeout: 30s
|
||||||
|
write_timeout: 1m
|
||||||
|
database:
|
||||||
|
driver: sqlite
|
||||||
|
path: ` + cfg.Database.Path + `
|
||||||
|
max_idle_conns: 10
|
||||||
|
max_open_conns: 100
|
||||||
|
conn_max_lifetime: 30m
|
||||||
|
log:
|
||||||
|
level: info
|
||||||
|
path: ` + cfg.Log.Path + `
|
||||||
|
max_size: 100
|
||||||
|
max_backups: 10
|
||||||
|
max_age: 30
|
||||||
|
compress: true
|
||||||
|
`
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte(yamlContent), 0o600))
|
||||||
|
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/settings/startup", nil)
|
||||||
|
|
||||||
|
h.GetStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "1m0s", resp.Config.Server.WriteTimeout)
|
||||||
|
assert.Equal(t, "30m0s", resp.Config.Database.ConnMaxLifetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingsHandler_SaveStartupSettings_StandardDurationRoundTrip(t *testing.T) {
|
||||||
|
cfg, configPath := createTestConfig(t)
|
||||||
|
h := NewSettingsHandler(cfg, "desktop", true, configPath)
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"config": map[string]interface{}{
|
||||||
|
"server": map[string]interface{}{
|
||||||
|
"port": 9826,
|
||||||
|
"read_timeout": "30s",
|
||||||
|
"write_timeout": "1m0s",
|
||||||
|
},
|
||||||
|
"database": map[string]interface{}{
|
||||||
|
"driver": "sqlite",
|
||||||
|
"path": filepath.Join(tmpDir, "test.db"),
|
||||||
|
"port": 3306,
|
||||||
|
"dbname": "nex",
|
||||||
|
"max_idle_conns": 10,
|
||||||
|
"max_open_conns": 100,
|
||||||
|
"conn_max_lifetime": "1h0m0s",
|
||||||
|
},
|
||||||
|
"log": map[string]interface{}{
|
||||||
|
"level": "info",
|
||||||
|
"path": filepath.Join(tmpDir, "log"),
|
||||||
|
"max_size": 100,
|
||||||
|
"max_backups": 10,
|
||||||
|
"max_age": 30,
|
||||||
|
"compress": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/settings/startup", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.SaveStartupSettings(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp startupSettingsResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "1m0s", resp.Config.Server.WriteTimeout)
|
||||||
|
assert.Equal(t, "1h0m0s", resp.Config.Database.ConnMaxLifetime)
|
||||||
|
|
||||||
|
savedCfg, err := config.LoadDesktopConfigAtPath(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1*time.Hour, savedCfg.Database.ConnMaxLifetime)
|
||||||
|
}
|
||||||
@@ -5,9 +5,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
"nex/backend/internal/service"
|
"nex/backend/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StatsHandler 统计处理器
|
// StatsHandler 统计处理器
|
||||||
|
|||||||
26
backend/internal/handler/version_handler.go
Normal file
26
backend/internal/handler/version_handler.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"nex/backend/pkg/buildinfo"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VersionHandler 提供后端构建版本信息。
|
||||||
|
type VersionHandler struct{}
|
||||||
|
|
||||||
|
// NewVersionHandler 创建版本信息处理器。
|
||||||
|
func NewVersionHandler() *VersionHandler {
|
||||||
|
return &VersionHandler{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVersion 返回构建注入的版本元数据。
|
||||||
|
func (h *VersionHandler) GetVersion(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"version": buildinfo.Version(),
|
||||||
|
"commit": buildinfo.Commit(),
|
||||||
|
"build_time": buildinfo.BuildTime(),
|
||||||
|
})
|
||||||
|
}
|
||||||
31
backend/internal/handler/version_handler_test.go
Normal file
31
backend/internal/handler/version_handler_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVersionHandler_GetVersion(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := NewVersionHandler()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||||
|
|
||||||
|
h.GetVersion(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
|
assert.Equal(t, "dev", result["version"])
|
||||||
|
assert.Equal(t, "unknown", result["commit"])
|
||||||
|
assert.Equal(t, "unknown", result["build_time"])
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
pkgErrors "nex/backend/pkg/errors"
|
pkgErrors "nex/backend/pkg/errors"
|
||||||
|
pkglogger "nex/backend/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamConfig 流式处理配置
|
// StreamConfig 流式处理配置
|
||||||
@@ -42,6 +44,14 @@ type StreamEvent struct {
|
|||||||
Done bool
|
Done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamResponse 表示上游流式 HTTP 响应。
|
||||||
|
type StreamResponse struct {
|
||||||
|
StatusCode int
|
||||||
|
Headers map[string]string
|
||||||
|
Body []byte
|
||||||
|
Events <-chan StreamEvent
|
||||||
|
}
|
||||||
|
|
||||||
// Client 协议无关的供应商客户端
|
// Client 协议无关的供应商客户端
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -50,19 +60,20 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProviderClient 供应商客户端接口
|
// ProviderClient 供应商客户端接口
|
||||||
|
//
|
||||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||||
type ProviderClient interface {
|
type ProviderClient interface {
|
||||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient 创建供应商客户端
|
// NewClient 创建供应商客户端
|
||||||
func NewClient() *Client {
|
func NewClient(logger *zap.Logger) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
},
|
},
|
||||||
logger: zap.L(),
|
logger: pkglogger.WithModule(logger, "provider.client"),
|
||||||
streamCfg: DefaultStreamConfig(),
|
streamCfg: DefaultStreamConfig(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -114,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendStream 发送流式请求
|
// SendStream 发送流式请求
|
||||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
|
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
|
||||||
var bodyReader io.Reader
|
var bodyReader io.Reader
|
||||||
if len(spec.Body) > 0 {
|
if len(spec.Body) > 0 {
|
||||||
bodyReader = bytes.NewReader(spec.Body)
|
bodyReader = bytes.NewReader(spec.Body)
|
||||||
@@ -137,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
|||||||
return nil, pkgErrors.ErrRequestSend.WithCause(err)
|
return nil, pkgErrors.ErrRequestSend.WithCause(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
respHeaders := extractResponseHeaders(resp.Header)
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
cancel()
|
cancel()
|
||||||
errBody, _ := io.ReadAll(resp.Body)
|
errBody, readErr := io.ReadAll(resp.Body)
|
||||||
if len(errBody) > 0 {
|
if readErr != nil {
|
||||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
return &StreamResponse{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Headers: respHeaders,
|
||||||
|
Body: errBody,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||||
|
|
||||||
return eventChan, nil
|
return &StreamResponse{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Headers: respHeaders,
|
||||||
|
Events: eventChan,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// readStream 读取 SSE 流
|
// readStream 读取 SSE 流
|
||||||
@@ -183,10 +203,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
if isNetworkError(err) {
|
if isNetworkError(err) {
|
||||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
c.logger.Error("流网络错误", zap.Error(err))
|
||||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||||
} else {
|
} else {
|
||||||
c.logger.Error("流读取错误", zap.String("error", err.Error()))
|
c.logger.Error("流读取错误", zap.Error(err))
|
||||||
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
|
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -203,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
idx, sepLen := findSSEFrameSeparator(dataBuf)
|
||||||
if idx == -1 {
|
if idx == -1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
rawEvent := dataBuf[:idx]
|
frameEnd := idx + sepLen
|
||||||
dataBuf = dataBuf[idx+2:]
|
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
|
||||||
|
dataBuf = dataBuf[frameEnd:]
|
||||||
|
|
||||||
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
|
if isSSEDoneFrame(rawEvent) {
|
||||||
|
eventChan <- StreamEvent{Data: rawEvent}
|
||||||
eventChan <- StreamEvent{Done: true}
|
eventChan <- StreamEvent{Done: true}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -220,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
if len(dataBuf) > 0 {
|
||||||
|
eventChan <- StreamEvent{Data: dataBuf}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSSEDoneFrame(frame []byte) bool {
|
||||||
|
payload, ok := sseFrameDataPayload(frame)
|
||||||
|
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||||
|
}
|
||||||
|
|
||||||
|
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||||
|
text := strings.TrimRight(string(frame), "\r\n")
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
var dataLines []string
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimRight(line, "\r")
|
||||||
|
if strings.HasPrefix(line, "data:") {
|
||||||
|
value := strings.TrimPrefix(line, "data:")
|
||||||
|
if strings.HasPrefix(value, " ") {
|
||||||
|
value = value[1:]
|
||||||
|
}
|
||||||
|
dataLines = append(dataLines, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(dataLines) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return strings.Join(dataLines, "\n"), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractResponseHeaders(header http.Header) map[string]string {
|
||||||
|
respHeaders := make(map[string]string)
|
||||||
|
for k, vs := range header {
|
||||||
|
if len(vs) > 0 {
|
||||||
|
respHeaders[k] = vs[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return respHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||||
|
lf := bytes.Index(data, []byte("\n\n"))
|
||||||
|
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||||
|
switch {
|
||||||
|
case lf < 0 && crlf < 0:
|
||||||
|
return -1, 0
|
||||||
|
case lf < 0:
|
||||||
|
return crlf, 4
|
||||||
|
case crlf < 0:
|
||||||
|
return lf, 2
|
||||||
|
case crlf <= lf:
|
||||||
|
return crlf, 4
|
||||||
|
default:
|
||||||
|
return lf, 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isNetworkError 判断是否为网络相关错误
|
// isNetworkError 判断是否为网络相关错误
|
||||||
func isNetworkError(err error) bool {
|
func isNetworkError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user