1
0

chore: 合并 dev-code-backend-format 到 master

This commit is contained in:
2026-04-24 18:20:12 +08:00
91 changed files with 1322 additions and 824 deletions

1
.gitignore vendored
View File

@@ -401,6 +401,7 @@ cython_debug/
# Custom # Custom
.claude .claude
.opencode .opencode
.codex
openspec/changes/archive openspec/changes/archive
temp temp
.agents .agents

View File

@@ -1,11 +1,11 @@
.PHONY: all dev build test lint clean \ .PHONY: all dev build test lint clean \
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \ backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-clean backend-deps backend-generate \ backend-lint backend-clean backend-deps backend-generate \
backend-db-up backend-db-down backend-db-status backend-db-create \ backend-db-up backend-db-down backend-db-status backend-db-create \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \ test-mysql-up test-mysql-down test-mysql test-mysql-quick \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \ frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \ desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \ desktop-dev desktop-test desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
desktop-prepare-frontend desktop-prepare-embedfs desktop-prepare-frontend desktop-prepare-embedfs
# ============================================ # ============================================
@@ -19,7 +19,7 @@ dev:
build: backend-build frontend-build build: backend-build frontend-build
@echo "✅ Build complete" @echo "✅ Build complete"
test: backend-test frontend-test test: backend-test desktop-test frontend-test
@echo "✅ All tests passed" @echo "✅ All tests passed"
lint: backend-lint frontend-lint lint: backend-lint frontend-lint
@@ -41,6 +41,9 @@ backend-dev:
cd backend && go run ./cmd/server cd backend && go run ./cmd/server
backend-test: backend-test:
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
backend-test-all:
cd backend && go test ./... -v cd backend && go test ./... -v
backend-test-unit: backend-test-unit:
@@ -179,6 +182,9 @@ desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🖥️ Starting desktop app in dev mode..." @echo "🖥️ Starting desktop app in dev mode..."
cd backend && go run ./cmd/desktop cd backend && go run ./cmd/desktop
desktop-test:
cd backend && go test ./cmd/desktop/... -v
desktop-package-mac: desktop-package-mac:
./scripts/build/package-macos.sh ./scripts/build/package-macos.sh

View File

@@ -294,6 +294,9 @@ make frontend-test-coverage # 前端覆盖率
## 开发 ## 开发
```bash ```bash
# 首次克隆后安装 Git hooks
lefthook install
# 顶层便捷命令 # 顶层便捷命令
make dev # 启动开发环境(并行启动后端和前端) make dev # 启动开发环境(并行启动后端和前端)
make build # 构建所有产物 make build # 构建所有产物

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

View File

@@ -609,6 +609,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节 - **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接 - **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配 - **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()` - **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片 - **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -6,19 +6,25 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"strings" "strings"
"go.uber.org/zap"
) )
func showError(title, message string) { func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title)) escapeAppleScript(message), escapeAppleScript(title))
exec.Command("osascript", "-e", script).Run() if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
} }
func showAbout() { func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway" message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
escapeAppleScript(message)) escapeAppleScript(message))
exec.Command("osascript", "-e", script).Run() if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示关于对话框失败", zap.Error(err))
}
} }
func escapeAppleScript(s string) string { func escapeAppleScript(s string) string {

View File

@@ -4,7 +4,6 @@ package main
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sync" "sync"
) )
@@ -63,7 +62,7 @@ func showError(title, message string) {
exec.Command("xmessage", "-center", exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run() fmt.Sprintf("%s: %s", title, message)).Run()
default: default:
fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message) dialogLogger().Error("无法显示错误对话框")
} }
} }
@@ -83,6 +82,6 @@ func showAbout() {
exec.Command("xmessage", "-center", exec.Command("xmessage", "-center",
fmt.Sprintf("关于 Nex Gateway: %s", message)).Run() fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
default: default:
fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message) dialogLogger().Info("关于 Nex Gateway")
} }
} }

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -13,10 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/getlantern/systray" "nex/embedfs"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -28,9 +25,13 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger"
"nex/embedfs" "github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
) )
var ( var (
@@ -51,12 +52,16 @@ func main() {
showError("Nex Gateway", "已有 Nex 实例运行") showError("Nex Gateway", "已有 Nex 实例运行")
os.Exit(1) os.Exit(1)
} }
defer singleLock.Unlock() defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil { if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err)) minimalLogger.Error("端口不可用", zap.Error(err))
showError("Nex Gateway", err.Error()) showError("Nex Gateway", err.Error())
os.Exit(1) return
} }
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
@@ -75,7 +80,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)
@@ -144,14 +153,14 @@ func main() {
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr)) zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
}() }()
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil { if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error())) zapLogger.Warn("无法打开浏览器", zap.Error(err))
} }
}() }()
@@ -193,7 +202,7 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
func setupStaticFiles(r *gin.Engine) { func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil { if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error())) zapLogger.Fatal("无法加载前端资源", zap.Error(err))
} }
getContentType := func(path string) string { getContentType := func(path string) string {
@@ -266,7 +275,7 @@ func setupSystray(port int) {
icon, err = embedfs.Assets.ReadFile("assets/icon.png") icon, err = embedfs.Assets.ReadFile("assets/icon.png")
} }
if err != nil { if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error())) zapLogger.Error("无法加载托盘图标", zap.Error(err))
} }
systray.SetIcon(icon) systray.SetIcon(icon)
systray.SetTitle("Nex Gateway") systray.SetTitle("Nex Gateway")
@@ -287,7 +296,9 @@ func setupSystray(port int) {
for { for {
select { select {
case <-mOpen.ClickedCh: case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port)) if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mAbout.ClickedCh: case <-mAbout.ClickedCh:
showAbout() showAbout()
case <-mQuit.ClickedCh: case <-mQuit.ClickedCh:
@@ -308,7 +319,9 @@ func doShutdown() {
if server != nil { if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
server.Shutdown(ctx) if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
} }
if shutdownCancel != nil { if shutdownCancel != nil {
@@ -346,8 +359,8 @@ func (s *SingletonLock) Lock() error {
return nil return nil
} }
func (s *SingletonLock) Unlock() { func (s *SingletonLock) Unlock() error {
s.flock.Unlock() return s.flock.Unlock()
} }
func openBrowser(url string) error { func openBrowser(url string) error {

View File

@@ -21,7 +21,7 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) { func TestCheckPortOccupied(t *testing.T) {
port := 19827 port := 19827
listener, err := net.Listen("tcp", ":19827") listener, err := net.Listen("tcp", "127.0.0.1:19827")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
@@ -47,13 +47,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) { func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828 port := 19828
listener, err := net.Listen("tcp", ":19828") listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
server := &http.Server{} server := &http.Server{ReadHeaderTimeout: time.Second}
go server.Serve(listener) defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil { if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err) t.Fatalf("首次加锁应成功,但返回错误: %v", err)
} }
defer lock.Unlock() defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
} }
func TestSingletonLock_DuplicateLockFails(t *testing.T) { func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
defer lock1.Unlock() defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
err := lock2.Lock() err := lock2.Lock()
if err == nil { if err == nil {
lock2.Unlock() if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil") t.Fatal("重复加锁应失败,但返回 nil")
} }
} }
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
lock1.Unlock() if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil { if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err) t.Fatalf("释放后重新加锁应成功: %v", err)
} }
lock2.Unlock() if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
} }
func TestSingletonLock_UnlockWithoutLock(t *testing.T) { func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock")) lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock() if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
} }

View File

@@ -6,9 +6,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/gin-gonic/gin"
"nex/embedfs" "nex/embedfs"
"github.com/gin-gonic/gin"
) )
func TestSetupStaticFiles(t *testing.T) { func TestSetupStaticFiles(t *testing.T) {

View File

@@ -44,7 +44,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -58,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values // DefaultConfig returns default config values
func DefaultConfig() *Config { func DefaultConfig() *Config {
// Use home dir for default paths // Use home dir for default paths
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
return &Config{ return &Config{
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err return "", err
} }
configDir := filepath.Join(homeDir, ".nex") configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil { if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err return "", err
} }
return configDir, nil return configDir, nil
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值 // setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) { func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826) v.SetDefault("server.port", 9826)
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper // 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定 // 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port")) bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout")) bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout")) bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver")) bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path")) bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host")) bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port")) bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user")) bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password")) bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname")) bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns")) bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns")) bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime")) bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level")) bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path")) bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size")) bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups")) bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age")) bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress")) bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
} }
// setupEnv 绑定环境变量 // setupEnv 绑定环境变量
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
// 配置文件不存在,创建默认配置文件 // 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil { writeErr := v.SafeWriteConfigAs(configPath)
// 忽略写入错误(可能目录已存在等) if writeErr == nil {
return nil return nil
} }
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
} }
return nil return nil
} }
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet) setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数) // 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:]) if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖) // 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" { if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists // Ensure directory exists
dir := filepath.Dir(configPath) dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
return os.WriteFile(configPath, data, 0600) return os.WriteFile(configPath, data, 0o600)
} }
// Validate validates the config // Validate validates the config

View File

@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml") configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg) data, err := yaml.Marshal(cfg)
require.NoError(t, err) require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644) err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err) require.NoError(t, err)
// 加载配置 // 加载配置

View File

@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string { func (UsageStats) TableName() string {
return "usage_stats" return "usage_stats"
} }

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message, Message: err.Message,
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat: case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加 // Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
default: default:
return body, nil return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import ( import (
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) { t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`)) _, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider) _, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码嵌入响应", func(t *testing.T) { t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`)) _, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码嵌入响应", func(t *testing.T) { t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{}) _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) { t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`)) _, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider) _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码重排序响应", func(t *testing.T) { t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`)) _, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码重排序响应", func(t *testing.T) { t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{}) _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages { for _, msg := range req.Messages {
decoded := decodeMessage(msg) decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...) canonicalMsgs = append(canonicalMsgs, decoded...)
} }
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
} }
// decodeMessage 解码 Anthropic 消息 // decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage { func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role { switch msg.Role {
case "user": case "user":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock var others []canonical.ContentBlock
for _, b := range blocks { for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 { if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}}) result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
} }
return result return result, nil
case "assistant": case "assistant":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 { if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock("")) blocks = append(blocks, canonical.NewTextBlock(""))
} }
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}} return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
} }
return nil return nil, nil
} }
// decodeContentBlocks 解码内容块列表 // decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock { func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)} return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any: case []any:
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m) block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil { if block != nil {
blocks = append(blocks, *block) blocks = append(blocks, *block)
} }
} }
} }
if len(blocks) > 0 { if len(blocks) > 0 {
return blocks return blocks, nil
} }
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil: case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default: default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
} }
} }
// decodeSingleContentBlock 解码单个内容块 // decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock { func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text} if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use": case "tool_use":
id, _ := m["id"].(string) id, ok := m["id"].(string)
name, _ := m["name"].(string) if !ok {
input, _ := json.Marshal(m["input"]) id = ""
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} }
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result": case "tool_result":
toolUseID, _ := m["tool_use_id"].(string) toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false isErr := false
if ie, ok := m["is_error"].(bool); ok { if ie, ok := m["is_error"].(bool); ok {
isErr = ie isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string: case string:
content = json.RawMessage(fmt.Sprintf("%q", cv)) content = json.RawMessage(fmt.Sprintf("%q", cv))
default: default:
content, _ = json.Marshal(cv) encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
} }
} else { } else {
content = json.RawMessage(`""`) content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID, ToolUseID: toolUseID,
Content: content, Content: content,
IsError: &isErr, IsError: &isErr,
} }, nil
case "thinking": case "thinking":
thinking, _ := m["thinking"].(string) thinking, ok := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking": case "redacted_thinking":
// 丢弃 // 丢弃
return nil return nil, nil
} }
return nil return nil, nil
} }
// decodeTools 解码工具定义 // decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "auto": case "auto":
return canonical.NewToolChoiceAuto() return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any": case "any":
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
case "tool": case "tool":
name, _ := v["name"].(string) name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
} }

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"]) assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"]) assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
} }
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息 // tool 消息应被合并到相邻 user 消息
foundToolResult := false foundToolResult := false
for _, m := range msgs { for _, m := range msgs {
msgMap := m.(map[string]any) msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" { if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if ok { if ok {
for _, c := range content { for _, c := range content {
block := c.(map[string]any) block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" { if block["type"] == "tool_result" {
foundToolResult = true foundToolResult = true
} }
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any) require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"]) assert.Equal(t, "user", firstMsg["role"])
} }
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"]) assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"]) assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"]) assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"]) assert.Equal(t, "你好", block["text"])
} }
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any) data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1) assert.Len(t, data, 1)
model := data[0].(map[string]any) model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"]) assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式 // created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string) createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any) userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"]) assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any) content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2) assert.Len(t, content, 2)
} }
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"] _, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning) assert.False(t, hasReasoning)
} }
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"]) assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"]) assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"]) assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk data := rawChunk
if len(d.utf8Remainder) > 0 { if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...) data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil d.utf8Remainder = nil
} }
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") { for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ") eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ") eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" { if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData)) chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
} }
eventType = "" eventType = ""
eventData = "" eventData = ""
} else if line == "" { case line == "":
// SSE 事件分隔符 continue
} }
} }

View File

@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"]) assert.Equal(t, "text", cb["type"])
} }
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"]) assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"]) assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"]) assert.Equal(t, "search", cb["name"])
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"]) assert.Equal(t, "thinking", cb["type"])
} }
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break break
} }
} }
delta := payload["delta"].(map[string]any) delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"]) assert.Equal(t, "end_turn", delta["stop_reason"])
} }
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break break
} }
} }
u := payload["usage"].(map[string]any) u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"]) assert.Equal(t, float64(88), u["output_tokens"])
} }

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
} }
func TestDecodeContentBlocks_Nil(t *testing.T) { func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil) blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text) assert.Equal(t, "", blocks[0].Text)
} }
func TestDecodeContentBlocks_String(t *testing.T) { func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello") blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text) assert.Equal(t, "hello", blocks[0].Text)
} }
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice) result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"]) r, ok := result.(map[string]any)
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"]) require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
}) })
} }
} }
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any) tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1) assert.Len(t, tools, 1)
tool := tools[0].(map[string]any) tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"]) assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"]) assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any) tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"]) assert.Equal(t, "auto", tc["type"])
} }
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"]) assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"]) assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -114,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
interfaceType := clientAdapter.DetectInterfaceType(nativePath) interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerHeaders := providerAdapter.BuildHeaders(provider) providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil { if err != nil {
@@ -122,7 +122,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl, URL: provider.BaseURL + providerURL,
Method: spec.Method, Method: spec.Method,
Headers: providerHeaders, Headers: providerHeaders,
Body: providerBody, Body: providerBody,
@@ -134,25 +134,22 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段 // Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 { if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return &spec, nil rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
} if rewriteErr != nil {
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(err), zap.Error(rewriteErr),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
return &spec, nil } else {
}
return &HTTPResponseSpec{ return &HTTPResponseSpec{
StatusCode: spec.StatusCode, StatusCode: spec.StatusCode,
Headers: spec.Headers, Headers: spec.Headers,
Body: rewrittenBody, Body: rewrittenBody,
}, nil }, nil
} }
}
}
return &spec, nil return &spec, nil
} }
@@ -182,12 +179,11 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段 // Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" { if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return NewPassthroughStreamConverter(), nil
}
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
}
return NewPassthroughStreamConverter(), nil return NewPassthroughStreamConverter(), nil
} }
@@ -306,7 +302,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body) models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelsResponse(models) encoded, err := clientAdapter.EncodeModelsResponse(models)
@@ -320,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body) info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelInfoResponse(info) encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil { if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
return encoded, nil return encoded, nil
@@ -334,7 +330,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body) req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeEmbeddingRequest(req, provider) return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -343,7 +339,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body) resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
if modelOverride != "" { if modelOverride != "" {
@@ -355,23 +351,24 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body) req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeRerankRequest(req, provider) return providerAdapter.EncodeRerankRequest(req, provider)
} }
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body) resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if err != nil { if decodeErr == nil {
return body, nil
}
if modelOverride != "" { if modelOverride != "" {
resp.Model = modelOverride resp.Model = modelOverride
} }
return clientAdapter.EncodeRerankResponse(resp) return clientAdapter.EncodeRerankResponse(resp)
} }
return body, nil
}
// DetectInterfaceType 检测接口类型 // DetectInterfaceType 检测接口类型
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) { func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
adapter, err := e.registry.Get(clientProtocol) adapter, err := e.registry.Get(clientProtocol)
@@ -391,9 +388,13 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error", "type": "internal_error",
}, },
} }
body, _ := json.Marshal(fallback) body, marshalErr := json.Marshal(fallback)
if marshalErr == nil {
return body, 500, nil return body, 500, nil
} }
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
}
body, statusCode := adapter.EncodeError(err) body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil return body, statusCode, nil
} }

View File

@@ -190,7 +190,9 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器 // noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{} type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器 // noopStreamEncoder 空流式编码器
@@ -615,6 +617,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
} }
return nil return nil
} }
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent { func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil { if d.flushFn != nil {
return d.flushFn() return d.flushFn()
@@ -634,6 +637,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
} }
return nil return nil
} }
func (e *engineTestStreamEncoder) Flush() [][]byte { func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil { if e.flushFn != nil {
return e.flushFn() return e.flushFn()

View File

@@ -138,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code), Code: string(err.Code),
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -248,7 +251,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -282,12 +289,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加 // Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加 // Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists { if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
} }
return json.Marshal(m) return json.Marshal(m)
default: default:

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text)) blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url": case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"}) blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart var result []contentPart
for _, item := range parts { for _, item := range parts {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text}) result = append(result, contentPart{Type: "text", Text: text})
case "refusal": case "refusal":
refusal, _ := m["refusal"].(string) refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal}) result = append(result, contentPart{Type: "refusal", Refusal: refusal})
} }
} }
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "function": case "function":
if fn, ok := v["function"].(map[string]any); ok { if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string) name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "custom": case "custom":
if custom, ok := v["custom"].(map[string]any); ok { if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string) name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "allowed_tools": case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok { if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string) mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" { if mode == "required" {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2) assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any) firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"]) assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"]) assert.Equal(t, "你是助手", firstMsg["content"])
} }
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any) require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any) toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, toolCalls, 1) assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any) tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"]) assert.Equal(t, "call_1", tc["id"])
} }
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"]) assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"]) assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any) choices, ok := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, ok)
msg := choice["message"].(map[string]any) choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"]) assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"]) assert.Equal(t, "stop", choice["finish_reason"])
} }
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okc := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any) tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, tcs, 1) assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"]) assert.Equal(t, "list", result["object"])
data := result["data"].([]any) data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
} }
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"]) assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"]) assert.Equal(t, "思考过程", msg["reasoning_content"])
} }

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ") data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n") data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload)) require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any) choices, okch := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"]) assert.Equal(t, "assistant", delta["role"])
} }

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"]) assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any) results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1) assert.Len(t, results, 1)
} }
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"]) assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any) ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok) require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"]) assert.Equal(t, tt.want, choice["finish_reason"])
}) })
} }

View File

@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
return gorm.Open(mysql.Open(dsn), gormConfig) return gorm.Open(mysql.Open(dsn), gormConfig)
default: default:
dbDir := filepath.Dir(cfg.Path) dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil { if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err) return nil, fmt.Errorf("创建数据库目录失败: %w", err)
} }
if zapLogger != nil { if zapLogger != nil {
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
zap.String("dir", migrationsSubDir)) zap.String("dir", migrationsSubDir))
} }
goose.SetDialect(gooseDialect) if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil { if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err return err
} }

View File

@@ -4,11 +4,11 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/config"
) )
func TestInit_SQLite(t *testing.T) { func TestInit_SQLite(t *testing.T) {

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
) )
func TestProviderHandler_CreateProvider_Success(t *testing.T) { func TestProviderHandler_CreateProvider_Success(t *testing.T) {

View File

@@ -9,23 +9,22 @@ import (
"strings" "strings"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()

View File

@@ -20,7 +20,6 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
if id, ok := requestID.(string); ok { if id, ok := requestID.(string); ok {
requestIDStr = id requestIDStr = id
} }
logger.Info("请求开始", logger.Info("请求开始",
pkglogger.Method(c.Request.Method), pkglogger.Method(c.Request.Method),
pkglogger.Path(path), pkglogger.Path(path),

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ModelHandler 模型管理处理器 // ModelHandler 模型管理处理器
@@ -58,13 +58,13 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model) err := h.modelService.Create(model)
if err != nil { if err != nil {
if err == appErrors.ErrProviderNotFound { if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在", "error": "供应商不存在",
}) })
return return
} }
if err == appErrors.ErrDuplicateModel { if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{ c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在", "error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code, "code": appErrors.ErrDuplicateModel.Code,
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id) model, err := h.modelService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id) err := h.modelService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ProviderHandler 供应商管理处理器 // ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider) err := h.providerService.Create(provider)
if err != nil { if err != nil {
if err == appErrors.ErrInvalidProviderID { if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message, "error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code, "code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id) provider, err := h.providerService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req) err := h.providerService.Update(id, req)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id) err := h.providerService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })

View File

@@ -3,19 +3,21 @@ package handler
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/service" "nex/backend/internal/service"
"nex/backend/pkg/modelid" "nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -159,7 +161,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换请求 // 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil { if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -167,7 +169,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求 // 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -175,7 +177,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段) // 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil { if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -191,7 +193,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
@@ -226,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
for event := range eventChan { for event := range eventChan {
if event.Error != nil { if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) h.logger.Error("流读取错误", zap.Error(event.Error))
break break
} }
if event.Done { if event.Done {
// flush 转换器 // flush 转换器
chunks := streamConverter.Flush() chunks := streamConverter.Flush()
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush()
} }
break break
} }
chunks := streamConverter.ProcessChunk(event.Data) chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush() break
} }
} }
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求 // isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat { if ifaceType != conversion.InterfaceTypeChat {
return false return false
} }
@@ -272,7 +290,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型 // 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels() models, err := h.providerService.ListEnabledModels()
if err != nil { if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
return return
} }
@@ -294,7 +312,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回 // 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList) body, err := adapter.EncodeModelsResponse(modelList)
if err != nil { if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return return
} }
@@ -342,8 +360,13 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// writeConversionError 写入转换错误 // writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok { var convErr *conversion.ConversionError
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) if errors.As(err, &convErr) {
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
if encodeErr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
return
}
c.Data(statusCode, "application/json", body) c.Data(statusCode, "application/json", body)
return return
} }

View File

@@ -8,27 +8,26 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine { func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper() t.Helper()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
@@ -844,7 +843,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok) require.True(t, ok)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]interface{}) first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"]) assert.Equal(t, "openai/gpt-4", first["id"])
} }
@@ -918,7 +918,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{} var req map[string]interface{}
json.Unmarshal(spec.Body, &req) require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{ return &conversion.HTTPResponseSpec{

View File

@@ -5,9 +5,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
) )
// StatsHandler 统计处理器 // StatsHandler 统计处理器

View File

@@ -51,6 +51,7 @@ type Client struct {
} }
// ProviderClient 供应商客户端接口 // ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface { type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
@@ -141,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
defer resp.Body.Close() defer resp.Body.Close()
cancel() cancel()
errBody, _ := io.ReadAll(resp.Body) errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d读取错误响应失败: %w", resp.StatusCode, readErr)
}
if len(errBody) > 0 { if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
} }
@@ -184,7 +188,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
if isNetworkError(err) { if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error())) c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else { } else {
c.logger.Error("流读取错误", zap.Error(err)) c.logger.Error("流读取错误", zap.Error(err))

View File

@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`)) _, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) { func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`)) _, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -140,12 +142,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")) _, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -165,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
var dataEvents [][]byte var dataEvents [][]byte
var doneEvents int var doneEvents int
for event := range eventChan { for event := range eventChan {
if event.Done { switch {
case event.Done:
doneEvents++ doneEvents++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataEvents = append(dataEvents, event.Data) dataEvents = append(dataEvents, event.Data)
} }
} }
@@ -215,7 +221,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method) assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`)) _, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -238,10 +245,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
})) }))
@@ -261,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
var dataCount int var dataCount int
var doneCount int var doneCount int
for event := range eventChan { for event := range eventChan {
if event.Done { switch {
case event.Done:
doneCount++ doneCount++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataCount++ dataCount++
} }
} }
@@ -279,10 +289,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -364,13 +376,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok { if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack() conn, _, _ := hijacker.Hijack()
if conn != nil { if conn != nil {
conn.Close() require.NoError(t, conn.Close())
} }
} }
})) }))

View File

@@ -3,10 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -3,13 +3,13 @@ package repository
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
) )
func setupTestDB(t *testing.T) *gorm.DB { func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -3,11 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type statsRepository struct { type statsRepository struct {
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
} }
func (r *statsRepository) Record(providerID, modelName string) error { func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02") now := time.Now()
todayTime, _ := time.Parse("2006-01-02", today) todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
stats := config.UsageStats{ stats := config.UsageStats{
ProviderID: providerID, ProviderID: providerID,

View File

@@ -1,11 +1,15 @@
package service package service
import ( import (
"github.com/google/uuid" "errors"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
) )
type modelService struct { type modelService struct {
@@ -108,8 +112,12 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error { func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复 return nil // 未找到,不重复
} }
return err
}
if excludeID != "" && existing.ID == excludeID { if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身 return nil // 排除自身
} }

View File

@@ -3,10 +3,10 @@ package service
import ( import (
"strings" "strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -4,10 +4,11 @@ import (
"strings" "strings"
"sync" "sync"
"go.uber.org/zap"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -34,7 +35,9 @@ func NewRoutingCache(
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) { func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
provider, err := c.providerRepo.GetByID(id) provider, err := c.providerRepo.GetByID(id)
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
} }
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
c.providers.Store(id, provider) c.providers.Store(id, provider)
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
key := providerID + "/" + modelName key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName) model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
} }
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
c.models.Store(key, model) c.models.Store(key, model)
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/" prefix := providerID + "/"
count := 0 count := 0
c.models.Range(func(key, value interface{}) bool { c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) { keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key) c.models.Delete(key)
count++ count++
} }

View File

@@ -5,11 +5,11 @@ import (
"sync" "sync"
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockModelRepo struct { type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
var openaiCount, anthropicCount int var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool { cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" { keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++ anthropicCount++
} }
return true return true

View File

@@ -1,9 +1,8 @@
package service package service
import ( import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
) )
type routingService struct { type routingService struct {

View File

@@ -3,12 +3,12 @@ package service
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
) )
func TestProviderService_Update(t *testing.T) { func TestProviderService_Update(t *testing.T) {
@@ -133,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
totalCount := 0 totalCount := 0
for _, r := range result { for _, r := range result {
totalCount += r["request_count"].(int) count, ok := r["request_count"].(int)
require.True(t, ok)
totalCount += count
} }
assert.Equal(t, 15, totalCount) assert.Equal(t, 15, totalCount)
} }

View File

@@ -5,6 +5,9 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -13,8 +16,6 @@ import (
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model") result := svc.Aggregate(tt.stats, "model")
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date") result := svc.Aggregate(tt.stats, "date")

View File

@@ -6,9 +6,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"nex/backend/internal/repository"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/repository"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -67,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
var counter *int64 var counter *int64
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter = v.(*int64) if existing, ok := v.(*int64); ok {
counter = existing
} else {
return
}
} else { } else {
val := int64(0) val := int64(0)
counter = &val counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter) actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded { if loaded {
counter = actual.(*int64) existing, ok := actual.(*int64)
if !ok {
return
}
counter = existing
} }
} }
@@ -117,13 +126,20 @@ func (b *StatsBuffer) flush() {
var entries []statEntry var entries []statEntry
b.counters.Range(func(key, value interface{}) bool { b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string) keyStr, ok := key.(string)
if !ok {
return true
}
parts := strings.Split(keyStr, "/") parts := strings.Split(keyStr, "/")
if len(parts) != 3 { if len(parts) != 3 {
return true return true
} }
counter := value.(*int64) counter, ok := value.(*int64)
if !ok {
return true
}
count := atomic.SwapInt64(counter, 0) count := atomic.SwapInt64(counter, 0)
if count > 0 { if count > 0 {
@@ -143,8 +159,17 @@ func (b *StatsBuffer) flush() {
success := 0 success := 0
for _, entry := range entries { for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date) date, err := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count)) if err != nil {
b.logger.Error("解析统计日期失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.String("date", entry.date),
zap.Error(err))
continue
}
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil { if err != nil {
b.logger.Error("批量更新统计失败", b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID), zap.String("provider_id", entry.providerID),
@@ -154,9 +179,11 @@ func (b *StatsBuffer) flush() {
key := entry.providerID + "/" + entry.modelName + "/" + entry.date key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter := v.(*int64) counter, ok := v.(*int64)
if ok {
atomic.AddInt64(counter, entry.count) atomic.AddInt64(counter, entry.count)
} }
}
} else { } else {
success++ success++
} }

View File

@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockStatsRepo struct { type mockStatsRepo struct {
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count += atomic.LoadInt64(counter) count += atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(3), count) assert.Equal(t, int64(3), count)
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter) count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(100), count) assert.Equal(t, int64(100), count)
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var beforeCount int64 var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
beforeCount = atomic.LoadInt64(counter) beforeCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), beforeCount) assert.Equal(t, int64(2), beforeCount)
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var afterCount int64 var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
afterCount = atomic.LoadInt64(counter) afterCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(0), afterCount) assert.Equal(t, int64(0), afterCount)
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter) count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), count) assert.Equal(t, int64(2), count)

View File

@@ -1,6 +1,7 @@
package errors package errors
import ( import (
stderrors "errors"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
if err == nil { if err == nil {
return nil, false return nil, false
} }
var appErr *AppError var appErr *AppError
if ok := is(err, &appErr); ok { if !stderrors.As(err, &appErr) {
return appErr, true
}
return nil, false return nil, false
} }
func is(err error, target interface{}) bool { return appErr, true
// 简单的类型断言
if e, ok := err.(*AppError); ok {
// 直接赋值
switch t := target.(type) {
case **AppError:
*t = e
return true
}
}
return false
} }

View File

@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
func TestAsAppError(t *testing.T) { func TestAsAppError(t *testing.T) {
t.Run("nil输入", func(t *testing.T) { t.Run("nil输入", func(t *testing.T) {
_, ok := AsAppError(nil) appErr, ok := AsAppError(nil)
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
}) })
t.Run("非AppError类型", func(t *testing.T) { t.Run("非AppError类型", func(t *testing.T) {
_, ok := AsAppError(errors.New("普通错误")) appErr, ok := AsAppError(errors.New("普通错误"))
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
} }

View File

@@ -19,7 +19,7 @@ func TestNew_StdoutOnly(t *testing.T) {
func TestNew_WithFileOutput(t *testing.T) { func TestNew_WithFileOutput(t *testing.T) {
dir := filepath.Join(os.TempDir(), "nex-logger-test") dir := filepath.Join(os.TempDir(), "nex-logger-test")
os.MkdirAll(dir, 0755) require.NoError(t, os.MkdirAll(dir, 0o755))
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
logger, err := New(Config{ logger, err := New(Config{

View File

@@ -6,10 +6,10 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"nex/backend/internal/config"
) )
func TestLoadConfig_DefaultValues(t *testing.T) { func TestLoadConfig_DefaultValues(t *testing.T) {
@@ -72,7 +72,7 @@ log:
max_age: 7 max_age: 7
compress: false compress: false
` `
err := os.WriteFile(configPath, []byte(yamlContent), 0644) err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err) require.NoError(t, err)
cfg, err := config.LoadConfigFromPath(configPath) cfg, err := config.LoadConfigFromPath(configPath)
@@ -103,7 +103,7 @@ server:
log: log:
level: warn level: warn
` `
err := os.WriteFile(configPath, []byte(yamlContent), 0644) err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err) require.NoError(t, err)
t.Setenv("NEX_SERVER_PORT", "9000") t.Setenv("NEX_SERVER_PORT", "9000")
@@ -147,7 +147,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
} }
defer func() { defer func() {
if originalConfig != nil { if originalConfig != nil {
_ = os.WriteFile(configPath, originalConfig, 0644) require.NoError(t, os.WriteFile(configPath, originalConfig, 0o600))
} }
}() }()

View File

@@ -11,20 +11,21 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai" openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
) )
func init() { func init() {
@@ -39,7 +40,8 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 默认返回成功,由各测试 case 覆盖 // 默认返回成功,由各测试 case 覆盖
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`)) _, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
})) }))
db := setupTestDB(t) db := setupTestDB(t)
@@ -124,7 +126,6 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
modelBody, _ := json.Marshal(map[string]string{ modelBody, _ := json.Marshal(map[string]string{
"provider_id": providerID, "provider_id": providerID,
"model_name": modelName, "model_name": modelName,
}) })
@@ -143,9 +144,10 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
// 配置上游返回 Anthropic 格式响应 // 配置上游返回 Anthropic 格式响应
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求被转换为 Anthropic 格式 // 验证请求被转换为 Anthropic 格式
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/v1/messages", r.URL.Path) assert.Equal(t, "/v1/messages", r.URL.Path)
assert.Contains(t, r.Header.Get("Content-Type"), "application/json") assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
@@ -166,7 +168,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
}, },
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp) require.NoError(t, json.NewEncoder(w).Encode(resp))
}) })
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
@@ -189,13 +191,16 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var resp map[string]any var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "chat.completion", resp["object"]) assert.Equal(t, "chat.completion", resp["object"])
choices := resp["choices"].([]any) choices, ok := resp["choices"].([]any)
require.True(t, ok)
require.Len(t, choices, 1) require.Len(t, choices, 1)
choice := choices[0].(map[string]any) choice, ok := choices[0].(map[string]any)
msg := choice["message"].(map[string]any) require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Contains(t, msg["content"], "Hello from Anthropic!") assert.Contains(t, msg["content"], "Hello from Anthropic!")
} }
@@ -203,9 +208,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
r, _, upstream := setupConversionTest(t) r, _, upstream := setupConversionTest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/chat/completions", r.URL.Path) assert.Equal(t, "/chat/completions", r.URL.Path)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key") assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
@@ -229,7 +235,7 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
}, },
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp) require.NoError(t, json.NewEncoder(w).Encode(resp))
}) })
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
@@ -252,12 +258,14 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var resp map[string]any var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "message", resp["type"]) assert.Equal(t, "message", resp["type"])
content := resp["content"].([]any) content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1) require.Len(t, content, 1)
block := content[0].(map[string]any) block, ok2 := content[0].(map[string]any)
require.True(t, ok2)
assert.Contains(t, block["text"], "Hello from OpenAI!") assert.Contains(t, block["text"], "Hello from OpenAI!")
} }
@@ -269,15 +277,17 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/chat/completions", r.URL.Path) assert.Equal(t, "/chat/completions", r.URL.Path)
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
// 上游返回上游模型名 // 上游返回上游模型名
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`)) _, err = w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
require.NoError(t, err)
}) })
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
@@ -304,15 +314,17 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/messages", r.URL.Path) assert.Equal(t, "/v1/messages", r.URL.Path)
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "claude-3-opus", req["model"]) assert.Equal(t, "claude-3-opus", req["model"])
// 上游返回上游模型名 // 上游返回上游模型名
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`)) _, err = w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
require.NoError(t, err)
}) })
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
@@ -352,7 +364,8 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -393,7 +406,8 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
"data: [DONE]\n\n", "data: [DONE]\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -447,11 +461,13 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var anthropicResp map[string]any var anthropicResp map[string]any
json.Unmarshal(anthropicBody, &anthropicResp) require.NoError(t, json.Unmarshal(anthropicBody, &anthropicResp))
data := anthropicResp["data"].([]any) data, okd := anthropicResp["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]any) first, okf := data[0].(map[string]any)
require.True(t, okf)
assert.Equal(t, "gpt-4", first["id"]) assert.Equal(t, "gpt-4", first["id"])
assert.Equal(t, "model", first["type"]) assert.Equal(t, "model", first["type"])
@@ -466,11 +482,12 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var openaiResp map[string]any var openaiResp map[string]any
json.Unmarshal(openaiBody, &err) require.NoError(t, json.Unmarshal(openaiBody, &openaiResp))
json.Unmarshal(openaiBody, &openaiResp) oaiData, oki := openaiResp["data"].([]any)
oaiData := openaiResp["data"].([]any) require.True(t, oki)
assert.Len(t, oaiData, 1) assert.Len(t, oaiData, 1)
firstOai := oaiData[0].(map[string]any) firstOai, okf2 := oaiData[0].(map[string]any)
require.True(t, okf2)
assert.Equal(t, "claude-3-opus", firstOai["id"]) assert.Equal(t, "claude-3-opus", firstOai["id"])
} }
@@ -537,7 +554,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
var created map[string]any var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "sk-test", created["api_key"]) assert.Equal(t, "sk-test", created["api_key"])
// 获取时应包含 protocol // 获取时应包含 protocol
@@ -547,7 +564,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var fetched map[string]any var fetched map[string]any
json.Unmarshal(w.Body.Bytes(), &fetched) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &fetched))
assert.Equal(t, "anthropic", fetched["protocol"]) assert.Equal(t, "anthropic", fetched["protocol"])
} }
@@ -570,11 +587,13 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
var created map[string]any var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "openai", created["protocol"]) assert.Equal(t, "openai", created["protocol"])
} }
// Suppress unused imports // Suppress unused imports
var _ = fmt.Sprintf var (
var _ = strings.Contains _ = fmt.Sprintf
var _ = time.Second _ = strings.Contains
_ = time.Second
)

View File

@@ -12,19 +12,20 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
openaiConv "nex/backend/internal/conversion/openai"
) )
func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) { func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
@@ -33,7 +34,8 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`)) _, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
})) }))
db := setupTestDB(t) db := setupTestDB(t)
@@ -115,11 +117,12 @@ func parseSSEEvents(body string) []map[string]string {
var currentEvent, currentData string var currentEvent, currentData string
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
currentEvent = strings.TrimPrefix(line, "event: ") currentEvent = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
currentData = strings.TrimPrefix(line, "data: ") currentData = strings.TrimPrefix(line, "data: ")
} else if line == "" && (currentEvent != "" || currentData != "") { case line == "" && (currentEvent != "" || currentData != ""):
events = append(events, map[string]string{ events = append(events, map[string]string{
"event": currentEvent, "event": currentEvent,
"data": currentData, "data": currentData,
@@ -157,7 +160,7 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/chat/completions", req.URL.Path) assert.Equal(t, "/chat/completions", req.URL.Path)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-001", "id": "chatcmpl-e2e-001",
"object": "chat.completion", "object": "chat.completion",
"created": 1700000000, "created": 1700000000,
@@ -171,7 +174,7 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
"usage": map[string]any{ "usage": map[string]any{
"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25, "prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25,
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -210,21 +213,23 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) { func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs, ok := reqBody["messages"].([]any)
require.True(t, ok)
assert.GreaterOrEqual(t, len(msgs), 3) assert.GreaterOrEqual(t, len(msgs), 3)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o", "id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"}, "index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"},
"finish_reason": "stop", "logprobs": nil, "finish_reason": "stop", "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -252,7 +257,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o", "id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -272,7 +277,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98}, "usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -319,7 +324,7 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o", "id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -328,7 +333,7 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50}, "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -353,7 +358,7 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3", "id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -365,7 +370,7 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
"prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83, "prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83,
"completion_tokens_details": map[string]any{"reasoning_tokens": 20}, "completion_tokens_details": map[string]any{"reasoning_tokens": 20},
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
@@ -393,7 +398,7 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o", "id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -406,7 +411,7 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47}, "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -569,14 +574,14 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_001", "type": "message", "role": "assistant", "id": "msg_e2e_001", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"}, {"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"},
}, },
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25}, "usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -611,18 +616,19 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) { func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
assert.NotNil(t, reqBody["system"]) assert.NotNil(t, reqBody["system"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_003", "type": "message", "role": "assistant", "id": "msg_e2e_003", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}}, "content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15}, "usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -643,7 +649,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_009", "type": "message", "role": "assistant", "id": "msg_e2e_009", "type": "message", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather", "type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather",
@@ -651,7 +657,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
}}, }},
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42}, "usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -689,7 +695,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_018", "type": "message", "role": "assistant", "id": "msg_e2e_018", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "thinking", "thinking": "这是一个逻辑推理问题..."}, {"type": "thinking", "thinking": "这是一个逻辑推理问题..."},
@@ -697,7 +703,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
}, },
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280}, "usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -724,12 +730,12 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_016", "type": "message", "role": "assistant", "id": "msg_e2e_016", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}}, "content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}},
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20}, "usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -752,12 +758,12 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_017", "type": "message", "role": "assistant", "id": "msg_e2e_017", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}}, "content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}},
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5", "model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10}, "usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -781,19 +787,20 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) { func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
metadata, _ := reqBody["metadata"].(map[string]any) metadata, _ := reqBody["metadata"].(map[string]any)
assert.Equal(t, "user_12345", metadata["user_id"]) assert.Equal(t, "user_12345", metadata["user_id"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_026", "type": "message", "role": "assistant", "id": "msg_e2e_026", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}}, "content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5}, "usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -814,7 +821,7 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_025", "type": "message", "role": "assistant", "id": "msg_e2e_025", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}}, "content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
@@ -822,7 +829,7 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
"input_tokens": 25, "output_tokens": 5, "input_tokens": 25, "output_tokens": 5,
"cache_creation_input_tokens": 15, "cache_read_input_tokens": 0, "cache_creation_input_tokens": 15, "cache_read_input_tokens": 0,
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -864,7 +871,8 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1353,21 +1361,22 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) { func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
tc, _ := reqBody["tool_choice"].(map[string]any) tc, _ := reqBody["tool_choice"].(map[string]any)
assert.Equal(t, "any", tc["type"]) assert.Equal(t, "any", tc["type"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tca", "type": "message", "role": "assistant", "id": "msg_e2e_tca", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}}, {"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}},
}, },
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30}, "usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1397,20 +1406,21 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) { func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
sys, ok := reqBody["system"].([]any) sys, ok := reqBody["system"].([]any)
require.True(t, ok, "system should be an array") require.True(t, ok, "system should be an array")
require.GreaterOrEqual(t, len(sys), 1) require.GreaterOrEqual(t, len(sys), 1)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_asys", "type": "message", "role": "assistant", "id": "msg_e2e_asys", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}}, "content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10}, "usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1433,21 +1443,22 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) { func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3) require.GreaterOrEqual(t, len(msgs), 3)
lastMsg := msgs[len(msgs)-1].(map[string]any) lastMsg := msgs[len(msgs)-1].(map[string]any)
assert.Equal(t, "user", lastMsg["role"]) assert.Equal(t, "user", lastMsg["role"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tr", "type": "message", "role": "assistant", "id": "msg_e2e_tr", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "北京当前晴天温度25°C。"}}, "content": []map[string]any{{"type": "text", "text": "北京当前晴天温度25°C。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20}, "usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1497,7 +1508,8 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1659,9 +1671,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) { func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3) require.GreaterOrEqual(t, len(msgs), 3)
toolMsg := msgs[2].(map[string]any) toolMsg := msgs[2].(map[string]any)
@@ -1669,7 +1682,7 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"]) assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o", "id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -1678,7 +1691,7 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -1722,7 +1735,8 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1879,7 +1893,8 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n", "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1902,5 +1917,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
assert.Contains(t, respBody, "正常") assert.Contains(t, respBody, "正常")
} }
var _ = fmt.Sprintf var (
var _ = time.Now _ = fmt.Sprintf
_ = time.Now
)

View File

@@ -7,16 +7,17 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
) )
func init() { func init() {
@@ -97,7 +98,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
assert.NotEmpty(t, createdModel.ID) assert.NotEmpty(t, createdModel.ID)
// 3. 列出 Provider // 3. 列出 Provider
@@ -106,7 +107,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var providers []domain.Provider var providers []domain.Provider
json.Unmarshal(w.Body.Bytes(), &providers) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
assert.Len(t, providers, 1) assert.Len(t, providers, 1)
assert.Equal(t, "sk-test-key", providers[0].APIKey) assert.Equal(t, "sk-test-key", providers[0].APIKey)
@@ -116,7 +117,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var models []domain.Model var models []domain.Model
json.Unmarshal(w.Body.Bytes(), &models) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
assert.Len(t, models, 1) assert.Len(t, models, 1)
assert.Equal(t, "gpt-4", models[0].ModelName) assert.Equal(t, "gpt-4", models[0].ModelName)
@@ -163,7 +164,7 @@ func TestAnthropic_ModelCreation(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
// 验证创建成功 // 验证创建成功
w = httptest.NewRecorder() w = httptest.NewRecorder()
@@ -194,9 +195,9 @@ func TestStats_RecordingAndQuery(t *testing.T) {
// 直接通过 repository 记录统计(模拟代理请求后的统计记录) // 直接通过 repository 记录统计(模拟代理请求后的统计记录)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
// 查询统计 // 查询统计
w = httptest.NewRecorder() w = httptest.NewRecorder()
@@ -205,7 +206,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var stats []domain.UsageStats var stats []domain.UsageStats
json.Unmarshal(w.Body.Bytes(), &stats) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
assert.Len(t, stats, 1) assert.Len(t, stats, 1)
assert.Equal(t, 3, stats[0].RequestCount) assert.Equal(t, 3, stats[0].RequestCount)

View File

@@ -4,11 +4,11 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config"
) )
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。 // setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。

View File

@@ -3,9 +3,10 @@ package tests
import ( import (
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"nex/backend/internal/config"
) )
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) { func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -11,9 +11,10 @@ package mocks
import ( import (
context "context" context "context"
reflect "reflect"
conversion "nex/backend/internal/conversion" conversion "nex/backend/internal/conversion"
provider "nex/backend/internal/provider" provider "nex/backend/internal/provider"
reflect "reflect"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,10 +10,11 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
time "time" time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,10 +10,11 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
time "time" time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

5
lefthook.yml Normal file
View File

@@ -0,0 +1,5 @@
pre-commit:
commands:
backend-lint:
glob: "backend/**/*.go"
run: cd backend && go tool golangci-lint run --new-from-rev HEAD ./...

View File

@@ -174,7 +174,7 @@
3. 设置默认值 3. 设置默认值
4. 绑定 CLI 参数 4. 绑定 CLI 参数
5. 绑定环境变量 5. 绑定环境变量
6. 读取配置文件 6. 读取配置文件(不存在时自动创建)
7. 反序列化到结构体 7. 反序列化到结构体
8. 验证配置 8. 验证配置
9. 打印配置摘要 9. 打印配置摘要

View File

@@ -31,16 +31,27 @@
- **THEN** SHALL 测试请求转换、响应转换、流式转换 - **THEN** SHALL 测试请求转换、响应转换、流式转换
- **THEN** SHALL 验证转换的准确性和完整性 - **THEN** SHALL 验证转换的准确性和完整性
#### Scenario: config 加载管道集成测试 #### Scenario: LoadConfigFromPath 默认值验证
- **WHEN** 运行 config 加载管道的集成测试 - **WHEN** 运行 config 加载管道的集成测试
- **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值 - **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值
- **THEN** SHALL 验证环境变量(`NEX_` 前缀)覆盖默认值
- **THEN** SHALL 验证 YAML 配置文件正确读取 - **THEN** SHALL 验证 YAML 配置文件正确读取
- **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值 - **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值
- **THEN** SHALL 验证首次启动自动创建配置文件 - **THEN** SHALL 验证首次启动自动创建配置文件
- **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致 - **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致
#### Scenario: 环境变量覆盖验证
- **WHEN** 设置 `NEX_SERVER_PORT=9000``NEX_LOG_LEVEL=debug`
- **THEN** SHALL 成功加载
- **THEN** 配置值 SHALL 反映环境变量覆盖
#### Scenario: 自动创建配置文件验证
- **WHEN** 调用 `LoadConfigFromPath` 并指向不存在的文件路径
- **THEN** SHALL 成功加载(不返回 `missing configuration for 'configPath'` 错误)
- **THEN** SHALL 返回默认配置对象
#### Scenario: handler 错误分支测试 #### Scenario: handler 错误分支测试
- **WHEN** 运行 handler 层的单元测试 - **WHEN** 运行 handler 层的单元测试