1
0

chore: 合并 dev-code-backend-format 到 master

This commit is contained in:
2026-04-24 18:20:12 +08:00
91 changed files with 1322 additions and 824 deletions

1
.gitignore vendored
View File

@@ -401,6 +401,7 @@ cython_debug/
# Custom # Custom
.claude .claude
.opencode .opencode
.codex
openspec/changes/archive openspec/changes/archive
temp temp
.agents .agents

View File

@@ -1,11 +1,11 @@
.PHONY: all dev build test lint clean \ .PHONY: all dev build test lint clean \
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \ backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-clean backend-deps backend-generate \ backend-lint backend-clean backend-deps backend-generate \
backend-db-up backend-db-down backend-db-status backend-db-create \ backend-db-up backend-db-down backend-db-status backend-db-create \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \ test-mysql-up test-mysql-down test-mysql test-mysql-quick \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \ frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \ desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \ desktop-dev desktop-test desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
desktop-prepare-frontend desktop-prepare-embedfs desktop-prepare-frontend desktop-prepare-embedfs
# ============================================ # ============================================
@@ -19,7 +19,7 @@ dev:
build: backend-build frontend-build build: backend-build frontend-build
@echo "✅ Build complete" @echo "✅ Build complete"
test: backend-test frontend-test test: backend-test desktop-test frontend-test
@echo "✅ All tests passed" @echo "✅ All tests passed"
lint: backend-lint frontend-lint lint: backend-lint frontend-lint
@@ -41,6 +41,9 @@ backend-dev:
cd backend && go run ./cmd/server cd backend && go run ./cmd/server
backend-test: backend-test:
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
backend-test-all:
cd backend && go test ./... -v cd backend && go test ./... -v
backend-test-unit: backend-test-unit:
@@ -179,6 +182,9 @@ desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🖥️ Starting desktop app in dev mode..." @echo "🖥️ Starting desktop app in dev mode..."
cd backend && go run ./cmd/desktop cd backend && go run ./cmd/desktop
desktop-test:
cd backend && go test ./cmd/desktop/... -v
desktop-package-mac: desktop-package-mac:
./scripts/build/package-macos.sh ./scripts/build/package-macos.sh

View File

@@ -294,6 +294,9 @@ make frontend-test-coverage # 前端覆盖率
## 开发 ## 开发
```bash ```bash
# 首次克隆后安装 Git hooks
lefthook install
# 顶层便捷命令 # 顶层便捷命令
make dev # 启动开发环境(并行启动后端和前端) make dev # 启动开发环境(并行启动后端和前端)
make build # 构建所有产物 make build # 构建所有产物

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

View File

@@ -609,6 +609,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节 - **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接 - **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配 - **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()` - **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片 - **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -6,19 +6,25 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"strings" "strings"
"go.uber.org/zap"
) )
func showError(title, message string) { func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title)) escapeAppleScript(message), escapeAppleScript(title))
exec.Command("osascript", "-e", script).Run() if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
} }
func showAbout() { func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway" message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
escapeAppleScript(message)) escapeAppleScript(message))
exec.Command("osascript", "-e", script).Run() if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示关于对话框失败", zap.Error(err))
}
} }
func escapeAppleScript(s string) string { func escapeAppleScript(s string) string {

View File

@@ -4,7 +4,6 @@ package main
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sync" "sync"
) )
@@ -63,7 +62,7 @@ func showError(title, message string) {
exec.Command("xmessage", "-center", exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run() fmt.Sprintf("%s: %s", title, message)).Run()
default: default:
fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message) dialogLogger().Error("无法显示错误对话框")
} }
} }
@@ -83,6 +82,6 @@ func showAbout() {
exec.Command("xmessage", "-center", exec.Command("xmessage", "-center",
fmt.Sprintf("关于 Nex Gateway: %s", message)).Run() fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
default: default:
fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message) dialogLogger().Info("关于 Nex Gateway")
} }
} }

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -13,10 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/getlantern/systray" "nex/embedfs"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -28,9 +25,13 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger"
"nex/embedfs" "github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
) )
var ( var (
@@ -51,12 +52,16 @@ func main() {
showError("Nex Gateway", "已有 Nex 实例运行") showError("Nex Gateway", "已有 Nex 实例运行")
os.Exit(1) os.Exit(1)
} }
defer singleLock.Unlock() defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil { if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err)) minimalLogger.Error("端口不可用", zap.Error(err))
showError("Nex Gateway", err.Error()) showError("Nex Gateway", err.Error())
os.Exit(1) return
} }
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
@@ -75,7 +80,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)
@@ -144,14 +153,14 @@ func main() {
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr)) zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
}() }()
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil { if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error())) zapLogger.Warn("无法打开浏览器", zap.Error(err))
} }
}() }()
@@ -193,7 +202,7 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
func setupStaticFiles(r *gin.Engine) { func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil { if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error())) zapLogger.Fatal("无法加载前端资源", zap.Error(err))
} }
getContentType := func(path string) string { getContentType := func(path string) string {
@@ -266,7 +275,7 @@ func setupSystray(port int) {
icon, err = embedfs.Assets.ReadFile("assets/icon.png") icon, err = embedfs.Assets.ReadFile("assets/icon.png")
} }
if err != nil { if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error())) zapLogger.Error("无法加载托盘图标", zap.Error(err))
} }
systray.SetIcon(icon) systray.SetIcon(icon)
systray.SetTitle("Nex Gateway") systray.SetTitle("Nex Gateway")
@@ -287,7 +296,9 @@ func setupSystray(port int) {
for { for {
select { select {
case <-mOpen.ClickedCh: case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port)) if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mAbout.ClickedCh: case <-mAbout.ClickedCh:
showAbout() showAbout()
case <-mQuit.ClickedCh: case <-mQuit.ClickedCh:
@@ -308,7 +319,9 @@ func doShutdown() {
if server != nil { if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
server.Shutdown(ctx) if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
} }
if shutdownCancel != nil { if shutdownCancel != nil {
@@ -346,8 +359,8 @@ func (s *SingletonLock) Lock() error {
return nil return nil
} }
func (s *SingletonLock) Unlock() { func (s *SingletonLock) Unlock() error {
s.flock.Unlock() return s.flock.Unlock()
} }
func openBrowser(url string) error { func openBrowser(url string) error {

View File

@@ -21,7 +21,7 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) { func TestCheckPortOccupied(t *testing.T) {
port := 19827 port := 19827
listener, err := net.Listen("tcp", ":19827") listener, err := net.Listen("tcp", "127.0.0.1:19827")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
@@ -47,13 +47,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) { func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828 port := 19828
listener, err := net.Listen("tcp", ":19828") listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
server := &http.Server{} server := &http.Server{ReadHeaderTimeout: time.Second}
go server.Serve(listener) defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil { if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err) t.Fatalf("首次加锁应成功,但返回错误: %v", err)
} }
defer lock.Unlock() defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
} }
func TestSingletonLock_DuplicateLockFails(t *testing.T) { func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
defer lock1.Unlock() defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
err := lock2.Lock() err := lock2.Lock()
if err == nil { if err == nil {
lock2.Unlock() if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil") t.Fatal("重复加锁应失败,但返回 nil")
} }
} }
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
lock1.Unlock() if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil { if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err) t.Fatalf("释放后重新加锁应成功: %v", err)
} }
lock2.Unlock() if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
} }
func TestSingletonLock_UnlockWithoutLock(t *testing.T) { func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock")) lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock() if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
} }

View File

@@ -6,9 +6,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/gin-gonic/gin"
"nex/embedfs" "nex/embedfs"
"github.com/gin-gonic/gin"
) )
func TestSetupStaticFiles(t *testing.T) { func TestSetupStaticFiles(t *testing.T) {

View File

@@ -44,7 +44,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -58,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values // DefaultConfig returns default config values
func DefaultConfig() *Config { func DefaultConfig() *Config {
// Use home dir for default paths // Use home dir for default paths
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
return &Config{ return &Config{
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err return "", err
} }
configDir := filepath.Join(homeDir, ".nex") configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil { if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err return "", err
} }
return configDir, nil return configDir, nil
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值 // setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) { func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826) v.SetDefault("server.port", 9826)
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper // 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定 // 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port")) bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout")) bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout")) bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver")) bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path")) bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host")) bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port")) bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user")) bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password")) bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname")) bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns")) bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns")) bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime")) bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level")) bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path")) bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size")) bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups")) bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age")) bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress")) bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
} }
// setupEnv 绑定环境变量 // setupEnv 绑定环境变量
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
// 配置文件不存在,创建默认配置文件 // 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil { writeErr := v.SafeWriteConfigAs(configPath)
// 忽略写入错误(可能目录已存在等) if writeErr == nil {
return nil return nil
} }
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
} }
return nil return nil
} }
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet) setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数) // 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:]) if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖) // 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" { if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists // Ensure directory exists
dir := filepath.Dir(configPath) dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
return os.WriteFile(configPath, data, 0600) return os.WriteFile(configPath, data, 0o600)
} }
// Validate validates the config // Validate validates the config

View File

@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml") configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg) data, err := yaml.Marshal(cfg)
require.NoError(t, err) require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644) err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err) require.NoError(t, err)
// 加载配置 // 加载配置

View File

@@ -6,15 +6,15 @@ import (
// Provider 供应商模型 // Provider 供应商模型
type Provider struct { type Provider struct {
ID string `gorm:"primaryKey" json:"id"` ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"` APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"` BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"` Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"` Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"` Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
} }
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name) // Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string { func (UsageStats) TableName() string {
return "usage_stats" return "usage_stats"
} }

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message, Message: err.Message,
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat: case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加 // Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
default: default:
return body, nil return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import ( import (
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -52,10 +53,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
nativePath string nativePath string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected string expected string
}{ }{
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"}, {"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"}, {"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
@@ -102,9 +103,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected bool expected bool
}{ }{
{"聊天", conversion.InterfaceTypeChat, true}, {"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true}, {"模型", conversion.InterfaceTypeModels, true},
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) { t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`)) _, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider) _, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码嵌入响应", func(t *testing.T) { t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`)) _, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码嵌入响应", func(t *testing.T) { t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{}) _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) { t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`)) _, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider) _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码重排序响应", func(t *testing.T) { t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`)) _, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码重排序响应", func(t *testing.T) { t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{}) _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages { for _, msg := range req.Messages {
decoded := decodeMessage(msg) decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...) canonicalMsgs = append(canonicalMsgs, decoded...)
} }
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
} }
// decodeMessage 解码 Anthropic 消息 // decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage { func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role { switch msg.Role {
case "user": case "user":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock var others []canonical.ContentBlock
for _, b := range blocks { for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 { if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}}) result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
} }
return result return result, nil
case "assistant": case "assistant":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 { if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock("")) blocks = append(blocks, canonical.NewTextBlock(""))
} }
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}} return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
} }
return nil return nil, nil
} }
// decodeContentBlocks 解码内容块列表 // decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock { func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)} return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any: case []any:
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m) block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil { if block != nil {
blocks = append(blocks, *block) blocks = append(blocks, *block)
} }
} }
} }
if len(blocks) > 0 { if len(blocks) > 0 {
return blocks return blocks, nil
} }
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil: case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default: default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
} }
} }
// decodeSingleContentBlock 解码单个内容块 // decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock { func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text} if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use": case "tool_use":
id, _ := m["id"].(string) id, ok := m["id"].(string)
name, _ := m["name"].(string) if !ok {
input, _ := json.Marshal(m["input"]) id = ""
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} }
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result": case "tool_result":
toolUseID, _ := m["tool_use_id"].(string) toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false isErr := false
if ie, ok := m["is_error"].(bool); ok { if ie, ok := m["is_error"].(bool); ok {
isErr = ie isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string: case string:
content = json.RawMessage(fmt.Sprintf("%q", cv)) content = json.RawMessage(fmt.Sprintf("%q", cv))
default: default:
content, _ = json.Marshal(cv) encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
} }
} else { } else {
content = json.RawMessage(`""`) content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID, ToolUseID: toolUseID,
Content: content, Content: content,
IsError: &isErr, IsError: &isErr,
} }, nil
case "thinking": case "thinking":
thinking, _ := m["thinking"].(string) thinking, ok := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking": case "redacted_thinking":
// 丢弃 // 丢弃
return nil return nil, nil
} }
return nil return nil, nil
} }
// decodeTools 解码工具定义 // decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "auto": case "auto":
return canonical.NewToolChoiceAuto() return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any": case "any":
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
case "tool": case "tool":
name, _ := v["name"].(string) name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
} }

View File

@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
result = append(result, m) result = append(result, m)
case "tool_result": case "tool_result":
m := map[string]any{ m := map[string]any{
"type": "tool_result", "type": "tool_result",
"tool_use_id": b.ToolUseID, "tool_use_id": b.ToolUseID,
} }
if b.Content != nil { if b.Content != nil {
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
} }
result := map[string]any{ result := map[string]any{
"id": resp.ID, "id": resp.ID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": resp.Model, "model": resp.Model,
"content": blocks, "content": blocks,
"stop_reason": sr, "stop_reason": sr,
"stop_sequence": nil, "stop_sequence": nil,
"usage": usage, "usage": usage,

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"]) assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"]) assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
} }
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息 // tool 消息应被合并到相邻 user 消息
foundToolResult := false foundToolResult := false
for _, m := range msgs { for _, m := range msgs {
msgMap := m.(map[string]any) msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" { if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if ok { if ok {
for _, c := range content { for _, c := range content {
block := c.(map[string]any) block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" { if block["type"] == "tool_result" {
foundToolResult = true foundToolResult = true
} }
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any) require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"]) assert.Equal(t, "user", firstMsg["role"])
} }
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"]) assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"]) assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"]) assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"]) assert.Equal(t, "你好", block["text"])
} }
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any) data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1) assert.Len(t, data, 1)
model := data[0].(map[string]any) model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"]) assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式 // created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string) createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any) userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"]) assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any) content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2) assert.Len(t, content, 2)
} }
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"] _, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning) assert.False(t, hasReasoning)
} }
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"]) assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"]) assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"]) assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk data := rawChunk
if len(d.utf8Remainder) > 0 { if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...) data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil d.utf8Remainder = nil
} }
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") { for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ") eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ") eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" { if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData)) chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
} }
eventType = "" eventType = ""
eventData = "" eventData = ""
} else if line == "" { case line == "":
// SSE 事件分隔符 continue
} }
} }
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
// processContentBlockStart 处理内容块开始事件 // processContentBlockStart 处理内容块开始事件
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
var raw struct { var raw struct {
Index int `json:"index"` Index int `json:"index"`
ContentBlock struct { ContentBlock struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`

View File

@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
checkValue string checkValue string
}{ }{
{ {
name: "text_delta", name: "text_delta",
deltaType: "text_delta", deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"}, deltaData: map[string]any{"type": "text_delta", "text": "你好"},
checkField: "text", checkField: "text",
checkValue: "你好", checkValue: "你好",
}, },
{ {
name: "input_json_delta", name: "input_json_delta",
deltaType: "input_json_delta", deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"}, deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
checkField: "partial_json", checkField: "partial_json",
checkValue: "{\"key\":", checkValue: "{\"key\":",
}, },
{ {
name: "thinking_delta", name: "thinking_delta",
deltaType: "thinking_delta", deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"}, deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
checkField: "thinking", checkField: "thinking",
checkValue: "思考中", checkValue: "思考中",
}, },
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
payload := map[string]any{ payload := map[string]any{
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
"delta": tt.deltaData, "delta": tt.deltaData,
} }
raw := makeAnthropicEvent("content_block_delta", payload) raw := makeAnthropicEvent("content_block_delta", payload)
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
"type": "content_block_start", "type": "content_block_start",
"index": 3, "index": 3,
"content_block": map[string]any{ "content_block": map[string]any{
"type": "web_search_tool_result", "type": "web_search_tool_result",
"tool_use_id": "search_1", "tool_use_id": "search_1",
}, },
} }
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
"delta": map[string]any{ "delta": map[string]any{
"type": "citations_delta", "type": "citations_delta",
"citation": map[string]any{"title": "ref1"}, "citation": map[string]any{"title": "ref1"},
}, },
} }
raw := makeAnthropicEvent("content_block_delta", payload) raw := makeAnthropicEvent("content_block_delta", payload)
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
}, },
} }
deltaPayload1 := map[string]any{ deltaPayload1 := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"}, "delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 25}, "usage": map[string]any{"output_tokens": 25},
} }
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
assert.Equal(t, 25, events[0].Usage.OutputTokens) assert.Equal(t, 25, events[0].Usage.OutputTokens)
deltaPayload2 := map[string]any{ deltaPayload2 := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"}, "delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 30}, "usage": map[string]any{"output_tokens": 30},
} }

View File

@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"]) assert.Equal(t, "text", cb["type"])
} }
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"]) assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"]) assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"]) assert.Equal(t, "search", cb["name"])
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"]) assert.Equal(t, "thinking", cb["type"])
} }
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break break
} }
} }
delta := payload["delta"].(map[string]any) delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"]) assert.Equal(t, "end_turn", delta["stop_reason"])
} }
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break break
} }
} }
u := payload["usage"].(map[string]any) u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"]) assert.Equal(t, float64(88), u["output_tokens"])
} }

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
} }
func TestDecodeContentBlocks_Nil(t *testing.T) { func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil) blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text) assert.Equal(t, "", blocks[0].Text)
} }
func TestDecodeContentBlocks_String(t *testing.T) { func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello") blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text) assert.Equal(t, "hello", blocks[0].Text)
} }
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice) result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"]) r, ok := result.(map[string]any)
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"]) require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
}) })
} }
} }
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any) tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1) assert.Len(t, tools, 1)
tool := tools[0].(map[string]any) tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"]) assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"]) assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any) tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"]) assert.Equal(t, "auto", tc["type"])
} }
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{ Usage: canonical.CanonicalUsage{
InputTokens: 100, InputTokens: 100,
OutputTokens: 50, OutputTokens: 50,
CacheReadTokens: &cacheRead, CacheReadTokens: &cacheRead,
CacheCreationTokens: &cacheCreation, CacheCreationTokens: &cacheCreation,
}, },
} }
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"]) assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"]) assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -6,22 +6,22 @@ import (
// MessagesRequest Anthropic Messages 请求 // MessagesRequest Anthropic Messages 请求
type MessagesRequest struct { type MessagesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
System any `json:"system,omitempty"` System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"` TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"` Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"` Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"` OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"` Container any `json:"container,omitempty"`
} }
// RequestMetadata 请求元数据 // RequestMetadata 请求元数据
@@ -122,8 +122,8 @@ type ContentBlock struct {
// ResponseUsage 响应用量 // ResponseUsage 响应用量
type ResponseUsage struct { type ResponseUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
} }

View File

@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
// EmbeddingData 嵌入数据项 // EmbeddingData 嵌入数据项
type EmbeddingData struct { type EmbeddingData struct {
Index int `json:"index"` Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串 Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
} }
// EmbeddingUsage 嵌入用量 // EmbeddingUsage 嵌入用量

View File

@@ -18,17 +18,17 @@ const (
type DeltaType string type DeltaType string
const ( const (
DeltaTypeText DeltaType = "text_delta" DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta" DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta" DeltaTypeThinking DeltaType = "thinking_delta"
) )
// StreamDelta 流式增量联合体 // StreamDelta 流式增量联合体
type StreamDelta struct { type StreamDelta struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"` PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"` Thinking string `json:"thinking,omitempty"`
} }
// StreamContentBlock 流式内容块联合体 // StreamContentBlock 流式内容块联合体
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
Message *StreamMessage `json:"message,omitempty"` Message *StreamMessage `json:"message,omitempty"`
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent // ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ContentBlock *StreamContentBlock `json:"content_block,omitempty"` ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"` Delta *StreamDelta `json:"delta,omitempty"`
// MessageDeltaEvent // MessageDeltaEvent
StopReason *StopReason `json:"stop_reason,omitempty"` StopReason *StopReason `json:"stop_reason,omitempty"`
Usage *CanonicalUsage `json:"usage,omitempty"` Usage *CanonicalUsage `json:"usage,omitempty"`
// ErrorEvent // ErrorEvent

View File

@@ -40,8 +40,8 @@ type ContentBlock struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
// ToolUseBlock // ToolUseBlock
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"` Input json.RawMessage `json:"input,omitempty"`
// ToolResultBlock // ToolResultBlock
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
// OutputFormat 输出格式联合体 // OutputFormat 输出格式联合体
type OutputFormat struct { type OutputFormat struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"` Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"` Strict *bool `json:"strict,omitempty"`
} }
// CanonicalRequest 规范请求 // CanonicalRequest 规范请求
type CanonicalRequest struct { type CanonicalRequest struct {
Model string `json:"model"` Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Messages []CanonicalMessage `json:"messages"` Messages []CanonicalMessage `json:"messages"`
Tools []CanonicalTool `json:"tools,omitempty"` Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"` ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"` Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"` Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"` OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"` ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
} }
// CanonicalUsage 规范用量 // CanonicalUsage 规范用量
type CanonicalUsage struct { type CanonicalUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"` CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"` CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"` ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
} }
// CanonicalResponse 规范响应 // CanonicalResponse 规范响应
type CanonicalResponse struct { type CanonicalResponse struct {
ID string `json:"id"` ID string `json:"id"`
Model string `json:"model"` Model string `json:"model"`
Content []ContentBlock `json:"content"` Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"` StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"` Usage CanonicalUsage `json:"usage"`
} }
// GetSystemString 获取系统消息字符串 // GetSystemString 获取系统消息字符串

View File

@@ -10,9 +10,9 @@ import (
func TestGetSystemString(t *testing.T) { func TestGetSystemString(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
system any system any
want string want string
}{ }{
{"string", "hello", "hello"}, {"string", "hello", "hello"},
{"nil", nil, ""}, {"nil", nil, ""},
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
func TestCanonicalResponse_RoundTrip(t *testing.T) { func TestCanonicalResponse_RoundTrip(t *testing.T) {
sr := StopReasonEndTurn sr := StopReasonEndTurn
resp := &CanonicalResponse{ resp := &CanonicalResponse{
ID: "resp-1", ID: "resp-1",
Model: "gpt-4", Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")}, Content: []ContentBlock{NewTextBlock("hello")},
StopReason: &sr, StopReason: &sr,
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5}, Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
} }
data, err := json.Marshal(resp) data, err := json.Marshal(resp)

View File

@@ -114,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
interfaceType := clientAdapter.DetectInterfaceType(nativePath) interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerHeaders := providerAdapter.BuildHeaders(provider) providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil { if err != nil {
@@ -122,7 +122,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl, URL: provider.BaseURL + providerURL,
Method: spec.Method, Method: spec.Method,
Headers: providerHeaders, Headers: providerHeaders,
Body: providerBody, Body: providerBody,
@@ -134,24 +134,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段 // Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 { if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return &spec, nil rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
} if rewriteErr != nil {
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(err), zap.Error(rewriteErr),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
return &spec, nil } else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
} }
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
} }
return &spec, nil return &spec, nil
} }
@@ -182,11 +179,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段 // Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" { if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return NewPassthroughStreamConverter(), nil return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
return NewPassthroughStreamConverter(), nil return NewPassthroughStreamConverter(), nil
} }
@@ -201,9 +197,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
} }
ctx := ConversionContext{ ctx := ConversionContext{
ConversionID: uuid.New().String(), ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat, InterfaceType: InterfaceTypeChat,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
return NewCanonicalStreamConverterWithMiddleware( return NewCanonicalStreamConverterWithMiddleware(
@@ -306,7 +302,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body) models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelsResponse(models) encoded, err := clientAdapter.EncodeModelsResponse(models)
@@ -320,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body) info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelInfoResponse(info) encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil { if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
return encoded, nil return encoded, nil
@@ -334,7 +330,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body) req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeEmbeddingRequest(req, provider) return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -343,7 +339,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body) resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
if modelOverride != "" { if modelOverride != "" {
@@ -355,21 +351,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body) req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeRerankRequest(req, provider) return providerAdapter.EncodeRerankRequest(req, provider)
} }
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body) resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if err != nil { if decodeErr == nil {
return body, nil if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
} }
if modelOverride != "" {
resp.Model = modelOverride return body, nil
}
return clientAdapter.EncodeRerankResponse(resp)
} }
// DetectInterfaceType 检测接口类型 // DetectInterfaceType 检测接口类型
@@ -391,8 +388,12 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error", "type": "internal_error",
}, },
} }
body, _ := json.Marshal(fallback) body, marshalErr := json.Marshal(fallback)
return body, 500, nil if marshalErr == nil {
return body, 500, nil
}
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
} }
body, statusCode := adapter.EncodeError(err) body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil return body, statusCode, nil

View File

@@ -38,8 +38,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
} }
} }
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName } func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" } func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough } func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType { func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
@@ -190,14 +190,16 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器 // noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{} type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil } return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器 // noopStreamEncoder 空流式编码器
type noopStreamEncoder struct{} type noopStreamEncoder struct{}
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil } func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil } func (e *noopStreamEncoder) Flush() [][]byte { return nil }
// ============ 测试用例 ============ // ============ 测试用例 ============
@@ -615,6 +617,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
} }
return nil return nil
} }
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent { func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil { if d.flushFn != nil {
return d.flushFn() return d.flushFn()
@@ -634,6 +637,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
} }
return nil return nil
} }
func (e *engineTestStreamEncoder) Flush() [][]byte { func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil { if e.flushFn != nil {
return e.flushFn() return e.flushFn()

View File

@@ -6,17 +6,17 @@ import "fmt"
type ErrorCode string type ErrorCode string
const ( const (
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT" ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD" ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE" ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE" ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR" ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR" ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR" ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR" ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION" ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE" ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED" ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
) )
// ConversionError 协议转换错误 // ConversionError 协议转换错误

View File

@@ -4,10 +4,10 @@ package conversion
type InterfaceType string type InterfaceType string
const ( const (
InterfaceTypeChat InterfaceType = "CHAT" InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS" InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO" InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS" InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK" InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH" InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
) )

View File

@@ -138,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code), Code: string(err.Code),
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -248,7 +251,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -282,12 +289,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加 // Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加 // Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists { if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
} }
return json.Marshal(m) return json.Marshal(m)
default: default:

View File

@@ -48,10 +48,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
nativePath string nativePath string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected string expected string
}{ }{
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"}, {"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/models", conversion.InterfaceTypeModels, "/models"}, {"模型", "/models", conversion.InterfaceTypeModels, "/models"},
@@ -92,9 +92,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected bool expected bool
}{ }{
{"聊天", conversion.InterfaceTypeChat, true}, {"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true}, {"模型", conversion.InterfaceTypeModels, true},

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text)) blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url": case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"}) blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
// contentPart 内容部分 // contentPart 内容部分
type contentPart struct { type contentPart struct {
Type string Type string
Text string Text string
Refusal string Refusal string
} }
// decodeContentParts 解码内容部分 // decodeContentParts 解码内容部分
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart var result []contentPart
for _, item := range parts { for _, item := range parts {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text}) result = append(result, contentPart{Type: "text", Text: text})
case "refusal": case "refusal":
refusal, _ := m["refusal"].(string) refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal}) result = append(result, contentPart{Type: "refusal", Refusal: refusal})
} }
} }
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "function": case "function":
if fn, ok := v["function"].(map[string]any); ok { if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string) name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "custom": case "custom":
if custom, ok := v["custom"].(map[string]any); ok { if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string) name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "allowed_tools": case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok { if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string) mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" { if mode == "required" {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
case map[string]any: case map[string]any:
if name, ok := v["name"].(string); ok { if name, ok := v["name"].(string); ok {
req.ToolChoice = map[string]any{ req.ToolChoice = map[string]any{
"type": "function", "type": "function",
"function": map[string]any{"name": name}, "function": map[string]any{"name": name},
} }
} }

View File

@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
"object": "list", "object": "list",
"data": data, "data": data,
"model": resp.Model, "model": resp.Model,
"usage": resp.Usage, "usage": resp.Usage,
}) })
} }

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2) assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any) firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"]) assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"]) assert.Equal(t, "你是助手", firstMsg["content"])
} }
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any) require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any) toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, toolCalls, 1) assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any) tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"]) assert.Equal(t, "call_1", tc["id"])
} }
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
func TestEncodeResponse_Basic(t *testing.T) { func TestEncodeResponse_Basic(t *testing.T) {
sr := canonical.StopReasonEndTurn sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "resp-1", ID: "resp-1",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5}, Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
} }
body, err := encodeResponse(resp) body, err := encodeResponse(resp)
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"]) assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"]) assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any) choices, ok := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, ok)
msg := choice["message"].(map[string]any) choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"]) assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"]) assert.Equal(t, "stop", choice["finish_reason"])
} }
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
sr := canonical.StopReasonToolUse sr := canonical.StopReasonToolUse
input := json.RawMessage(`{"q":"test"}`) input := json.RawMessage(`{"q":"test"}`)
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "resp-2", ID: "resp-2",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)}, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
StopReason: &sr, StopReason: &sr,
} }
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okc := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any) tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, tcs, 1) assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"]) assert.Equal(t, "list", result["object"])
data := result["data"].([]any) data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
} }
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"]) assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"]) assert.Equal(t, "思考过程", msg["reasoning_content"])
} }

View File

@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
idx := 0 idx := 0
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
// 先发送一个文本 chunk // 先发送一个文本 chunk
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
// 连续两个 refusal delta chunk // 连续两个 refusal delta chunk
for _, text := range []string{"拒绝", "原因"} { for _, text := range []string{"拒绝", "原因"} {
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx0 := 0 idx0 := 0
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-mt", "id": "chatcmpl-mt",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx1 := 1 idx1 := 1
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-mt", "id": "chatcmpl-mt",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-multi", "id": "chatcmpl-multi",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
}, },
} }
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-multi", "id": "chatcmpl-multi",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-utf8", "id": "chatcmpl-utf8",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
idx := 0 idx := 0
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-tc", "id": "chatcmpl-tc",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
}, },
} }
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-tc", "id": "chatcmpl-tc",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,

View File

@@ -10,9 +10,9 @@ import (
// StreamEncoder OpenAI 流式编码器 // StreamEncoder OpenAI 流式编码器
type StreamEncoder struct { type StreamEncoder struct {
bufferedStart *canonical.CanonicalStreamEvent bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int toolCallIndexMap map[string]int
nextToolCallIndex int nextToolCallIndex int
} }
// NewStreamEncoder 创建 OpenAI 流式编码器 // NewStreamEncoder 创建 OpenAI 流式编码器
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte { func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
chunk := map[string]any{ chunk := map[string]any{
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"delta": delta, "delta": delta,
}}, }},
} }
return e.marshalChunk(chunk) return e.marshalChunk(chunk)

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ") data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n") data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload)) require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any) choices, okch := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"]) assert.Equal(t, "assistant", delta["role"])
} }

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"]) assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any) results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1) assert.Len(t, results, 1)
} }
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
reasoning := 20 reasoning := 20
sr := canonical.StopReasonEndTurn sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "r1", ID: "r1",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{ Usage: canonical.CanonicalUsage{
InputTokens: 100, InputTokens: 100,
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"]) assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any) ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok) require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"]) assert.Equal(t, tt.want, choice["finish_reason"])
}) })
} }

View File

@@ -4,42 +4,42 @@ import "encoding/json"
// ChatCompletionRequest OpenAI Chat Completion 请求 // ChatCompletionRequest OpenAI Chat Completion 请求
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"` PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"` N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"` Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"` Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"` TopLogprobs *int `json:"top_logprobs,omitempty"`
// 已废弃字段 // 已废弃字段
Functions []FunctionDef `json:"functions,omitempty"` Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"` FunctionCall any `json:"function_call,omitempty"`
} }
// Message OpenAI 消息 // Message OpenAI 消息
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content any `json:"content"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
// 已废弃 // 已废弃
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"` FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
@@ -88,8 +88,8 @@ type FunctionDef struct {
// ResponseFormat OpenAI 响应格式 // ResponseFormat OpenAI 响应格式
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"` JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
} }
// JSONSchemaDef JSON Schema 定义 // JSONSchemaDef JSON Schema 定义
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
// Choice OpenAI 选择项 // Choice OpenAI 选择项
type Choice struct { type Choice struct {
Index int `json:"index"` Index int `json:"index"`
Message *Message `json:"message,omitempty"` Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"` Delta *Message `json:"delta,omitempty"`
FinishReason *string `json:"finish_reason"` FinishReason *string `json:"finish_reason"`
@@ -127,10 +127,10 @@ type Choice struct {
// Usage OpenAI 用量 // Usage OpenAI 用量
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
} }

View File

@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
return gorm.Open(mysql.Open(dsn), gormConfig) return gorm.Open(mysql.Open(dsn), gormConfig)
default: default:
dbDir := filepath.Dir(cfg.Path) dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil { if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err) return nil, fmt.Errorf("创建数据库目录失败: %w", err)
} }
if zapLogger != nil { if zapLogger != nil {
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
zap.String("dir", migrationsSubDir)) zap.String("dir", migrationsSubDir))
} }
goose.SetDialect(gooseDialect) if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil { if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err return err
} }

View File

@@ -4,11 +4,11 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/config"
) )
func TestInit_SQLite(t *testing.T) { func TestInit_SQLite(t *testing.T) {

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
) )
func TestProviderHandler_CreateProvider_Success(t *testing.T) { func TestProviderHandler_CreateProvider_Success(t *testing.T) {
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(mockSvc) h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"id": "p1", "id": "p1",
"name": "Test", "name": "Test",
"api_key": "sk-test", "api_key": "sk-test",
"base_url": "https://api.test.com", "base_url": "https://api.test.com",
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -9,23 +9,22 @@ import (
"strings" "strings"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()

View File

@@ -20,7 +20,6 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
if id, ok := requestID.(string); ok { if id, ok := requestID.(string); ok {
requestIDStr = id requestIDStr = id
} }
logger.Info("请求开始", logger.Info("请求开始",
pkglogger.Method(c.Request.Method), pkglogger.Method(c.Request.Method),
pkglogger.Path(path), pkglogger.Path(path),

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ModelHandler 模型管理处理器 // ModelHandler 模型管理处理器
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model) err := h.modelService.Create(model)
if err != nil { if err != nil {
if err == appErrors.ErrProviderNotFound { if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在", "error": "供应商不存在",
}) })
return return
} }
if err == appErrors.ErrDuplicateModel { if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{ c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在", "error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code, "code": appErrors.ErrDuplicateModel.Code,
}) })
return return
} }
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id) model, err := h.modelService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id) err := h.modelService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ProviderHandler 供应商管理处理器 // ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider) err := h.providerService.Create(provider)
if err != nil { if err != nil {
if err == appErrors.ErrInvalidProviderID { if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message, "error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code, "code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id) provider, err := h.providerService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req) err := h.providerService.Update(id, req)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id) err := h.providerService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })

View File

@@ -3,30 +3,32 @@ package handler
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/service" "nex/backend/internal/service"
"nex/backend/pkg/modelid" "nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
// ProxyHandler 统一代理处理器 // ProxyHandler 统一代理处理器
type ProxyHandler struct { type ProxyHandler struct {
engine *conversion.ConversionEngine engine *conversion.ConversionEngine
client provider.ProviderClient client provider.ProviderClient
routingService service.RoutingService routingService service.RoutingService
providerService service.ProviderService providerService service.ProviderService
statsService service.StatsService statsService service.StatsService
logger *zap.Logger logger *zap.Logger
} }
// NewProxyHandler 创建统一代理处理器 // NewProxyHandler 创建统一代理处理器
@@ -138,7 +140,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
targetProvider := conversion.NewTargetProvider( targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL, routeResult.Provider.BaseURL,
routeResult.Provider.APIKey, routeResult.Provider.APIKey,
routeResult.Model.ModelName, // 上游模型名,用于请求改写 routeResult.Model.ModelName, // 上游模型名,用于请求改写
) )
// 判断是否流式 // 判断是否流式
@@ -159,7 +161,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换请求 // 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil { if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -167,7 +169,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求 // 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -175,7 +177,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段) // 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil { if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -191,7 +193,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
@@ -226,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
for event := range eventChan { for event := range eventChan {
if event.Error != nil { if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) h.logger.Error("流读取错误", zap.Error(event.Error))
break break
} }
if event.Done { if event.Done {
// flush 转换器 // flush 转换器
chunks := streamConverter.Flush() chunks := streamConverter.Flush()
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush()
} }
break break
} }
chunks := streamConverter.ProcessChunk(event.Data) chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush() break
} }
} }
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求 // isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat { if ifaceType != conversion.InterfaceTypeChat {
return false return false
} }
@@ -272,7 +290,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型 // 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels() models, err := h.providerService.ListEnabledModels()
if err != nil { if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
return return
} }
@@ -294,7 +312,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回 // 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList) body, err := adapter.EncodeModelsResponse(modelList)
if err != nil { if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return return
} }
@@ -342,8 +360,13 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// writeConversionError 写入转换错误 // writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok { var convErr *conversion.ConversionError
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) if errors.As(err, &convErr) {
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
if encodeErr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
return
}
c.Data(statusCode, "application/json", body) c.Data(statusCode, "application/json", body)
return return
} }

View File

@@ -8,27 +8,26 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine { func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper() t.Helper()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
@@ -844,7 +843,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok) require.True(t, ok)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]interface{}) first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"]) assert.Equal(t, "openai/gpt-4", first["id"])
} }
@@ -918,7 +918,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{} var req map[string]interface{}
json.Unmarshal(spec.Body, &req) require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{ return &conversion.HTTPResponseSpec{

View File

@@ -5,9 +5,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
) )
// StatsHandler 统计处理器 // StatsHandler 统计处理器

View File

@@ -51,6 +51,7 @@ type Client struct {
} }
// ProviderClient 供应商客户端接口 // ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface { type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
@@ -141,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
defer resp.Body.Close() defer resp.Body.Close()
cancel() cancel()
errBody, _ := io.ReadAll(resp.Body) errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d读取错误响应失败: %w", resp.StatusCode, readErr)
}
if len(errBody) > 0 { if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
} }
@@ -184,7 +188,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
if isNetworkError(err) { if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error())) c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else { } else {
c.logger.Error("流读取错误", zap.Error(err)) c.logger.Error("流读取错误", zap.Error(err))

View File

@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`)) _, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) { func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`)) _, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -140,12 +142,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")) _, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -165,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
var dataEvents [][]byte var dataEvents [][]byte
var doneEvents int var doneEvents int
for event := range eventChan { for event := range eventChan {
if event.Done { switch {
case event.Done:
doneEvents++ doneEvents++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataEvents = append(dataEvents, event.Data) dataEvents = append(dataEvents, event.Data)
} }
} }
@@ -215,7 +221,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method) assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`)) _, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -238,10 +245,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
})) }))
@@ -261,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
var dataCount int var dataCount int
var doneCount int var doneCount int
for event := range eventChan { for event := range eventChan {
if event.Done { switch {
case event.Done:
doneCount++ doneCount++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataCount++ dataCount++
} }
} }
@@ -279,10 +289,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -364,13 +376,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok { if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack() conn, _, _ := hijacker.Hijack()
if conn != nil { if conn != nil {
conn.Close() require.NoError(t, conn.Close())
} }
} }
})) }))

View File

@@ -3,10 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -3,13 +3,13 @@ package repository
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
) )
func setupTestDB(t *testing.T) *gorm.DB { func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -3,11 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type statsRepository struct { type statsRepository struct {
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
} }
func (r *statsRepository) Record(providerID, modelName string) error { func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02") now := time.Now()
todayTime, _ := time.Parse("2006-01-02", today) todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
stats := config.UsageStats{ stats := config.UsageStats{
ProviderID: providerID, ProviderID: providerID,

View File

@@ -1,11 +1,15 @@
package service package service
import ( import (
"github.com/google/uuid" "errors"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
) )
type modelService struct { type modelService struct {
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error { func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil { if err != nil {
return nil // 未找到,不重复 if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复
}
return err
} }
if excludeID != "" && existing.ID == excludeID { if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身 return nil // 排除自身

View File

@@ -3,10 +3,10 @@ package service
import ( import (
"strings" "strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -4,10 +4,11 @@ import (
"strings" "strings"
"sync" "sync"
"go.uber.org/zap"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -34,7 +35,9 @@ func NewRoutingCache(
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) { func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
provider, err := c.providerRepo.GetByID(id) provider, err := c.providerRepo.GetByID(id)
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
} }
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
c.providers.Store(id, provider) c.providers.Store(id, provider)
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
key := providerID + "/" + modelName key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName) model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
} }
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
c.models.Store(key, model) c.models.Store(key, model)
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/" prefix := providerID + "/"
count := 0 count := 0
c.models.Range(func(key, value interface{}) bool { c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) { keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key) c.models.Delete(key)
count++ count++
} }

View File

@@ -5,11 +5,11 @@ import (
"sync" "sync"
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockModelRepo struct { type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
var openaiCount, anthropicCount int var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool { cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" { keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++ anthropicCount++
} }
return true return true

View File

@@ -1,9 +1,8 @@
package service package service
import ( import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
) )
type routingService struct { type routingService struct {

View File

@@ -3,12 +3,12 @@ package service
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
) )
func TestProviderService_Update(t *testing.T) { func TestProviderService_Update(t *testing.T) {
@@ -133,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
totalCount := 0 totalCount := 0
for _, r := range result { for _, r := range result {
totalCount += r["request_count"].(int) count, ok := r["request_count"].(int)
require.True(t, ok)
totalCount += count
} }
assert.Equal(t, 15, totalCount) assert.Equal(t, 15, totalCount)
} }

View File

@@ -5,6 +5,9 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -13,8 +16,6 @@ import (
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )
@@ -134,7 +135,7 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache) svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
@@ -148,7 +149,7 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db) repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache) svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -160,7 +161,7 @@ func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db) repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache) svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -176,7 +177,7 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache) svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -202,7 +203,7 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache) svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent-id", map[string]interface{}{ err := svc.Update("nonexistent-id", map[string]interface{}{
@@ -215,7 +216,7 @@ func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache) svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -241,7 +242,7 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db) repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache) svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -259,7 +260,7 @@ func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db) repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache) svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model") result := svc.Aggregate(tt.stats, "model")
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date") result := svc.Aggregate(tt.stats, "date")
@@ -448,7 +451,7 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db) repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache) svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"} provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
@@ -474,7 +477,7 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db) cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache) svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

View File

@@ -6,9 +6,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"nex/backend/internal/repository"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/repository"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -67,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
var counter *int64 var counter *int64
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter = v.(*int64) if existing, ok := v.(*int64); ok {
counter = existing
} else {
return
}
} else { } else {
val := int64(0) val := int64(0)
counter = &val counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter) actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded { if loaded {
counter = actual.(*int64) existing, ok := actual.(*int64)
if !ok {
return
}
counter = existing
} }
} }
@@ -117,13 +126,20 @@ func (b *StatsBuffer) flush() {
var entries []statEntry var entries []statEntry
b.counters.Range(func(key, value interface{}) bool { b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string) keyStr, ok := key.(string)
if !ok {
return true
}
parts := strings.Split(keyStr, "/") parts := strings.Split(keyStr, "/")
if len(parts) != 3 { if len(parts) != 3 {
return true return true
} }
counter := value.(*int64) counter, ok := value.(*int64)
if !ok {
return true
}
count := atomic.SwapInt64(counter, 0) count := atomic.SwapInt64(counter, 0)
if count > 0 { if count > 0 {
@@ -143,8 +159,17 @@ func (b *StatsBuffer) flush() {
success := 0 success := 0
for _, entry := range entries { for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date) date, err := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count)) if err != nil {
b.logger.Error("解析统计日期失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.String("date", entry.date),
zap.Error(err))
continue
}
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil { if err != nil {
b.logger.Error("批量更新统计失败", b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID), zap.String("provider_id", entry.providerID),
@@ -154,8 +179,10 @@ func (b *StatsBuffer) flush() {
key := entry.providerID + "/" + entry.modelName + "/" + entry.date key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter := v.(*int64) counter, ok := v.(*int64)
atomic.AddInt64(counter, entry.count) if ok {
atomic.AddInt64(counter, entry.count)
}
} }
} else { } else {
success++ success++

View File

@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockStatsRepo struct { type mockStatsRepo struct {
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
count += atomic.LoadInt64(counter) if ok {
count += atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(3), count) assert.Equal(t, int64(3), count)
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
count = atomic.LoadInt64(counter) if ok {
count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(100), count) assert.Equal(t, int64(100), count)
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var beforeCount int64 var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
beforeCount = atomic.LoadInt64(counter) if ok {
beforeCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), beforeCount) assert.Equal(t, int64(2), beforeCount)
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var afterCount int64 var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
afterCount = atomic.LoadInt64(counter) if ok {
afterCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(0), afterCount) assert.Equal(t, int64(0), afterCount)
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
count = atomic.LoadInt64(counter) if ok {
count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), count) assert.Equal(t, int64(2), count)

View File

@@ -1,6 +1,7 @@
package errors package errors
import ( import (
stderrors "errors"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
if err == nil { if err == nil {
return nil, false return nil, false
} }
var appErr *AppError
if ok := is(err, &appErr); ok {
return appErr, true
}
return nil, false
}
func is(err error, target interface{}) bool { var appErr *AppError
// 简单的类型断言 if !stderrors.As(err, &appErr) {
if e, ok := err.(*AppError); ok { return nil, false
// 直接赋值
switch t := target.(type) {
case **AppError:
*t = e
return true
}
} }
return false
return appErr, true
} }

View File

@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
func TestAsAppError(t *testing.T) { func TestAsAppError(t *testing.T) {
t.Run("nil输入", func(t *testing.T) { t.Run("nil输入", func(t *testing.T) {
_, ok := AsAppError(nil) appErr, ok := AsAppError(nil)
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
}) })
t.Run("非AppError类型", func(t *testing.T) { t.Run("非AppError类型", func(t *testing.T) {
_, ok := AsAppError(errors.New("普通错误")) appErr, ok := AsAppError(errors.New("普通错误"))
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
} }

View File

@@ -19,7 +19,7 @@ func TestNew_StdoutOnly(t *testing.T) {
func TestNew_WithFileOutput(t *testing.T) { func TestNew_WithFileOutput(t *testing.T) {
dir := filepath.Join(os.TempDir(), "nex-logger-test") dir := filepath.Join(os.TempDir(), "nex-logger-test")
os.MkdirAll(dir, 0755) require.NoError(t, os.MkdirAll(dir, 0o755))
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
logger, err := New(Config{ logger, err := New(Config{
@@ -81,7 +81,7 @@ func TestParseLevel(t *testing.T) {
{"info", true}, {"info", true},
{"warn", true}, {"warn", true},
{"error", true}, {"error", true},
{"", true}, // 默认为 info {"", true}, // 默认为 info
{"invalid", true}, // 默认为 info {"invalid", true}, // 默认为 info
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -22,9 +22,9 @@ func newRotateWriter(cfg Config) *lumberjack.Logger {
return &lumberjack.Logger{ return &lumberjack.Logger{
Filename: logFilePath(cfg.Path), Filename: logFilePath(cfg.Path),
MaxSize: maxSize, // MB MaxSize: maxSize, // MB
MaxBackups: maxBackups, MaxBackups: maxBackups,
MaxAge: maxAge, // days MaxAge: maxAge, // days
Compress: cfg.Compress, Compress: cfg.Compress,
} }
} }

View File

@@ -6,10 +6,10 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"nex/backend/internal/config"
) )
func TestLoadConfig_DefaultValues(t *testing.T) { func TestLoadConfig_DefaultValues(t *testing.T) {
@@ -72,7 +72,7 @@ log:
max_age: 7 max_age: 7
compress: false compress: false
` `
err := os.WriteFile(configPath, []byte(yamlContent), 0644) err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err) require.NoError(t, err)
cfg, err := config.LoadConfigFromPath(configPath) cfg, err := config.LoadConfigFromPath(configPath)
@@ -103,7 +103,7 @@ server:
log: log:
level: warn level: warn
` `
err := os.WriteFile(configPath, []byte(yamlContent), 0644) err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err) require.NoError(t, err)
t.Setenv("NEX_SERVER_PORT", "9000") t.Setenv("NEX_SERVER_PORT", "9000")
@@ -147,7 +147,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
} }
defer func() { defer func() {
if originalConfig != nil { if originalConfig != nil {
_ = os.WriteFile(configPath, originalConfig, 0644) require.NoError(t, os.WriteFile(configPath, originalConfig, 0o600))
} }
}() }()

View File

@@ -11,20 +11,21 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai" openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
) )
func init() { func init() {
@@ -39,7 +40,8 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 默认返回成功,由各测试 case 覆盖 // 默认返回成功,由各测试 case 覆盖
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`)) _, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
})) }))
db := setupTestDB(t) db := setupTestDB(t)
@@ -124,7 +126,6 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
modelBody, _ := json.Marshal(map[string]string{ modelBody, _ := json.Marshal(map[string]string{
"provider_id": providerID, "provider_id": providerID,
"model_name": modelName, "model_name": modelName,
}) })
@@ -143,9 +144,10 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
// 配置上游返回 Anthropic 格式响应 // 配置上游返回 Anthropic 格式响应
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求被转换为 Anthropic 格式 // 验证请求被转换为 Anthropic 格式
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/v1/messages", r.URL.Path) assert.Equal(t, "/v1/messages", r.URL.Path)
assert.Contains(t, r.Header.Get("Content-Type"), "application/json") assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
@@ -166,7 +168,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
}, },
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp) require.NoError(t, json.NewEncoder(w).Encode(resp))
}) })
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
@@ -189,13 +191,16 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var resp map[string]any var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "chat.completion", resp["object"]) assert.Equal(t, "chat.completion", resp["object"])
choices := resp["choices"].([]any) choices, ok := resp["choices"].([]any)
require.True(t, ok)
require.Len(t, choices, 1) require.Len(t, choices, 1)
choice := choices[0].(map[string]any) choice, ok := choices[0].(map[string]any)
msg := choice["message"].(map[string]any) require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Contains(t, msg["content"], "Hello from Anthropic!") assert.Contains(t, msg["content"], "Hello from Anthropic!")
} }
@@ -203,9 +208,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
r, _, upstream := setupConversionTest(t) r, _, upstream := setupConversionTest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/chat/completions", r.URL.Path) assert.Equal(t, "/chat/completions", r.URL.Path)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key") assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
@@ -229,7 +235,7 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
}, },
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp) require.NoError(t, json.NewEncoder(w).Encode(resp))
}) })
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
@@ -252,12 +258,14 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var resp map[string]any var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "message", resp["type"]) assert.Equal(t, "message", resp["type"])
content := resp["content"].([]any) content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1) require.Len(t, content, 1)
block := content[0].(map[string]any) block, ok2 := content[0].(map[string]any)
require.True(t, ok2)
assert.Contains(t, block["text"], "Hello from OpenAI!") assert.Contains(t, block["text"], "Hello from OpenAI!")
} }
@@ -269,21 +277,23 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/chat/completions", r.URL.Path) assert.Equal(t, "/chat/completions", r.URL.Path)
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
// 上游返回上游模型名 // 上游返回上游模型名
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`)) _, err = w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
require.NoError(t, err)
}) })
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "openai_p/gpt-4", // 客户端发送统一 ID "model": "openai_p/gpt-4", // 客户端发送统一 ID
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
@@ -304,21 +314,23 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/messages", r.URL.Path) assert.Equal(t, "/v1/messages", r.URL.Path)
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any var req map[string]any
json.Unmarshal(body, &req) require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "claude-3-opus", req["model"]) assert.Equal(t, "claude-3-opus", req["model"])
// 上游返回上游模型名 // 上游返回上游模型名
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`)) _, err = w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
require.NoError(t, err)
}) })
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
reqBody := map[string]any{ reqBody := map[string]any{
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID "model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
"max_tokens": 1024, "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
} }
@@ -352,7 +364,8 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -393,7 +406,8 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
"data: [DONE]\n\n", "data: [DONE]\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -447,11 +461,13 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var anthropicResp map[string]any var anthropicResp map[string]any
json.Unmarshal(anthropicBody, &anthropicResp) require.NoError(t, json.Unmarshal(anthropicBody, &anthropicResp))
data := anthropicResp["data"].([]any) data, okd := anthropicResp["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]any) first, okf := data[0].(map[string]any)
require.True(t, okf)
assert.Equal(t, "gpt-4", first["id"]) assert.Equal(t, "gpt-4", first["id"])
assert.Equal(t, "model", first["type"]) assert.Equal(t, "model", first["type"])
@@ -466,11 +482,12 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var openaiResp map[string]any var openaiResp map[string]any
json.Unmarshal(openaiBody, &err) require.NoError(t, json.Unmarshal(openaiBody, &openaiResp))
json.Unmarshal(openaiBody, &openaiResp) oaiData, oki := openaiResp["data"].([]any)
oaiData := openaiResp["data"].([]any) require.True(t, oki)
assert.Len(t, oaiData, 1) assert.Len(t, oaiData, 1)
firstOai := oaiData[0].(map[string]any) firstOai, okf2 := oaiData[0].(map[string]any)
require.True(t, okf2)
assert.Equal(t, "claude-3-opus", firstOai["id"]) assert.Equal(t, "claude-3-opus", firstOai["id"])
} }
@@ -537,7 +554,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
var created map[string]any var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "sk-test", created["api_key"]) assert.Equal(t, "sk-test", created["api_key"])
// 获取时应包含 protocol // 获取时应包含 protocol
@@ -547,7 +564,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var fetched map[string]any var fetched map[string]any
json.Unmarshal(w.Body.Bytes(), &fetched) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &fetched))
assert.Equal(t, "anthropic", fetched["protocol"]) assert.Equal(t, "anthropic", fetched["protocol"])
} }
@@ -570,11 +587,13 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
require.Equal(t, 201, w.Code) require.Equal(t, 201, w.Code)
var created map[string]any var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "openai", created["protocol"]) assert.Equal(t, "openai", created["protocol"])
} }
// Suppress unused imports // Suppress unused imports
var _ = fmt.Sprintf var (
var _ = strings.Contains _ = fmt.Sprintf
var _ = time.Second _ = strings.Contains
_ = time.Second
)

View File

@@ -12,19 +12,20 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
openaiConv "nex/backend/internal/conversion/openai"
) )
func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) { func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
@@ -33,7 +34,8 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`)) _, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
})) }))
db := setupTestDB(t) db := setupTestDB(t)
@@ -115,11 +117,12 @@ func parseSSEEvents(body string) []map[string]string {
var currentEvent, currentData string var currentEvent, currentData string
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
currentEvent = strings.TrimPrefix(line, "event: ") currentEvent = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
currentData = strings.TrimPrefix(line, "data: ") currentData = strings.TrimPrefix(line, "data: ")
} else if line == "" && (currentEvent != "" || currentData != "") { case line == "" && (currentEvent != "" || currentData != ""):
events = append(events, map[string]string{ events = append(events, map[string]string{
"event": currentEvent, "event": currentEvent,
"data": currentData, "data": currentData,
@@ -157,21 +160,21 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/chat/completions", req.URL.Path) assert.Equal(t, "/chat/completions", req.URL.Path)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-001", "id": "chatcmpl-e2e-001",
"object": "chat.completion", "object": "chat.completion",
"created": 1700000000, "created": 1700000000,
"model": "gpt-4o", "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{"role": "assistant", "content": "你好我是AI助手。"}, "message": map[string]any{"role": "assistant", "content": "你好我是AI助手。"},
"finish_reason": "stop", "finish_reason": "stop",
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{ "usage": map[string]any{
"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25, "prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25,
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -210,21 +213,23 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) { func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs, ok := reqBody["messages"].([]any)
require.True(t, ok)
assert.GreaterOrEqual(t, len(msgs), 3) assert.GreaterOrEqual(t, len(msgs), 3)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o", "id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"}, "index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"},
"finish_reason": "stop", "logprobs": nil, "finish_reason": "stop", "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -252,7 +257,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o", "id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
@@ -272,7 +277,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98}, "usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -286,9 +291,9 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"function": map[string]any{ "function": map[string]any{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"parameters": map[string]any{ "parameters": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"}, "required": []string{"city"},
}, },
}, },
}}, }},
@@ -319,22 +324,22 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o", "id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."}, "message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
"finish_reason": "length", "finish_reason": "length",
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50}, "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}}, "messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
"max_tokens": 30, "max_tokens": 30,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -353,11 +358,11 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3", "id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{"role": "assistant", "content": "答案是61。"}, "message": map[string]any{"role": "assistant", "content": "答案是61。"},
"finish_reason": "stop", "finish_reason": "stop",
"logprobs": nil, "logprobs": nil,
}}, }},
@@ -365,12 +370,12 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
"prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83, "prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83,
"completion_tokens_details": map[string]any{"reasoning_tokens": 20}, "completion_tokens_details": map[string]any{"reasoning_tokens": 20},
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/o3", "model": "openai_p/o3",
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}}, "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -393,12 +398,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o", "id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{ "message": map[string]any{
"role": "assistant", "role": "assistant",
"content": nil, "content": nil,
"refusal": "抱歉,我无法提供涉及危险活动的信息。", "refusal": "抱歉,我无法提供涉及危险活动的信息。",
}, },
@@ -406,12 +411,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47}, "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "做坏事"}}, "messages": []map[string]any{{"role": "user", "content": "做坏事"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -453,9 +458,9 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "你好"}}, "messages": []map[string]any{{"role": "user", "content": "你好"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -497,14 +502,14 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{ "tools": []map[string]any{{
"type": "function", "type": "function",
"function": map[string]any{ "function": map[string]any{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"parameters": map[string]any{ "parameters": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
}, },
}, },
@@ -546,9 +551,9 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "hi"}}, "messages": []map[string]any{{"role": "user", "content": "hi"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -569,14 +574,14 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_001", "type": "message", "role": "assistant", "id": "msg_e2e_001", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"}, {"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"},
}, },
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25}, "usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -611,24 +616,25 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) { func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
assert.NotNil(t, reqBody["system"]) assert.NotNil(t, reqBody["system"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_003", "type": "message", "role": "assistant", "id": "msg_e2e_003", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}}, "content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15}, "usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"system": "你是编程助手", "system": "你是编程助手",
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}}, "messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -643,7 +649,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_009", "type": "message", "role": "assistant", "id": "msg_e2e_009", "type": "message", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather", "type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather",
@@ -651,7 +657,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
}}, }},
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42}, "usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -661,9 +667,9 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
"tools": []map[string]any{{ "tools": []map[string]any{{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"input_schema": map[string]any{ "input_schema": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"}, "required": []string{"city"},
}, },
}}, }},
"tool_choice": map[string]any{"type": "auto"}, "tool_choice": map[string]any{"type": "auto"},
@@ -689,7 +695,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_018", "type": "message", "role": "assistant", "id": "msg_e2e_018", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "thinking", "thinking": "这是一个逻辑推理问题..."}, {"type": "thinking", "thinking": "这是一个逻辑推理问题..."},
@@ -697,7 +703,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
}, },
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280}, "usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -724,12 +730,12 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_016", "type": "message", "role": "assistant", "id": "msg_e2e_016", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}}, "content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}},
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20}, "usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -752,18 +758,18 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_017", "type": "message", "role": "assistant", "id": "msg_e2e_017", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}}, "content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}},
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5", "model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10}, "usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}}, "messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"stop_sequences": []string{"5"}, "stop_sequences": []string{"5"},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -781,19 +787,20 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) { func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
metadata, _ := reqBody["metadata"].(map[string]any) metadata, _ := reqBody["metadata"].(map[string]any)
assert.Equal(t, "user_12345", metadata["user_id"]) assert.Equal(t, "user_12345", metadata["user_id"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_026", "type": "message", "role": "assistant", "id": "msg_e2e_026", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}}, "content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5}, "usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -814,21 +821,21 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_025", "type": "message", "role": "assistant", "id": "msg_e2e_025", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}}, "content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{ "usage": map[string]any{
"input_tokens": 25, "output_tokens": 5, "input_tokens": 25, "output_tokens": 5,
"cache_creation_input_tokens": 15, "cache_read_input_tokens": 0, "cache_creation_input_tokens": 15, "cache_read_input_tokens": 0,
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}}, "system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
"messages": []map[string]any{{"role": "user", "content": "你好"}}, "messages": []map[string]any{{"role": "user", "content": "你好"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -864,7 +871,8 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -874,7 +882,7 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "你好"}}, "messages": []map[string]any{{"role": "user", "content": "你好"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -922,7 +930,7 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
"messages": []map[string]any{{"role": "user", "content": "1+1=?"}}, "messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024}, "thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -961,14 +969,14 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{ json.NewEncoder(w).Encode(map[string]any{
"id": "msg_cross_001", "type": "message", "role": "assistant", "id": "msg_cross_001", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "跨协议响应"}}, "content": []map[string]any{{"type": "text", "text": "跨协议响应"}},
"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
}) })
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model", "model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "Hello"}}, "messages": []map[string]any{{"role": "user", "content": "Hello"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1050,9 +1058,9 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model", "model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "Hello"}}, "messages": []map[string]any{{"role": "user", "content": "Hello"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1092,7 +1100,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4", "max_tokens": 1024, "model": "openai_p/gpt-4", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "Hello"}}, "messages": []map[string]any{{"role": "user", "content": "Hello"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1128,7 +1136,7 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/nonexistent", "model": "openai_p/nonexistent",
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1183,11 +1191,11 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
"content": nil, "content": nil,
"tool_calls": []map[string]any{ "tool_calls": []map[string]any{
{ {
"id": "call_ptc_1", "type": "function", "id": "call_ptc_1", "type": "function",
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"北京"}`}, "function": map[string]any{"name": "get_weather", "arguments": `{"city":"北京"}`},
}, },
{ {
"id": "call_ptc_2", "type": "function", "id": "call_ptc_2", "type": "function",
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"上海"}`}, "function": map[string]any{"name": "get_weather", "arguments": `{"city":"上海"}`},
}, },
}, },
@@ -1201,7 +1209,7 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}}, "messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
"tools": []map[string]any{{ "tools": []map[string]any{{
"type": "function", "type": "function",
@@ -1242,10 +1250,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{ json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-stop", "object": "chat.completion", "created": 1700000060, "model": "gpt-4o", "id": "chatcmpl-e2e-stop", "object": "chat.completion", "created": 1700000060, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "}, "message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
"finish_reason": "stop", "finish_reason": "stop",
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
}) })
@@ -1253,9 +1261,9 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}}, "messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"stop": []string{"5"}, "stop": []string{"5"},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1291,7 +1299,7 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "危险内容"}}, "messages": []map[string]any{{"role": "user", "content": "危险内容"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1353,21 +1361,22 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) { func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
tc, _ := reqBody["tool_choice"].(map[string]any) tc, _ := reqBody["tool_choice"].(map[string]any)
assert.Equal(t, "any", tc["type"]) assert.Equal(t, "any", tc["type"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tca", "type": "message", "role": "assistant", "id": "msg_e2e_tca", "type": "message", "role": "assistant",
"content": []map[string]any{ "content": []map[string]any{
{"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}}, {"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}},
}, },
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30}, "usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1397,20 +1406,21 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) { func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
sys, ok := reqBody["system"].([]any) sys, ok := reqBody["system"].([]any)
require.True(t, ok, "system should be an array") require.True(t, ok, "system should be an array")
require.GreaterOrEqual(t, len(sys), 1) require.GreaterOrEqual(t, len(sys), 1)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_asys", "type": "message", "role": "assistant", "id": "msg_e2e_asys", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}}, "content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10}, "usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1433,21 +1443,22 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) { func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3) require.GreaterOrEqual(t, len(msgs), 3)
lastMsg := msgs[len(msgs)-1].(map[string]any) lastMsg := msgs[len(msgs)-1].(map[string]any)
assert.Equal(t, "user", lastMsg["role"]) assert.Equal(t, "user", lastMsg["role"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tr", "type": "message", "role": "assistant", "id": "msg_e2e_tr", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "北京当前晴天温度25°C。"}}, "content": []map[string]any{{"type": "text", "text": "北京当前晴天温度25°C。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil, "model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20}, "usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1497,7 +1508,8 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1559,7 +1571,7 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model", "model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{ "tools": []map[string]any{{
"type": "function", "type": "function",
@@ -1634,14 +1646,14 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{ json.NewEncoder(w).Encode(map[string]any{
"id": "msg_cross_stop", "type": "message", "role": "assistant", "id": "msg_cross_stop", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "被截断的内容..."}}, "content": []map[string]any{{"type": "text", "text": "被截断的内容..."}},
"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil, "model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 10, "output_tokens": 20}, "usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
}) })
}) })
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model", "model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "长文"}}, "messages": []map[string]any{{"role": "user", "content": "长文"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1659,9 +1671,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) { func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
r, upstream := setupE2ETest(t) r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any var reqBody map[string]any
json.Unmarshal(body, &reqBody) require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any) msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3) require.GreaterOrEqual(t, len(msgs), 3)
toolMsg := msgs[2].(map[string]any) toolMsg := msgs[2].(map[string]any)
@@ -1669,16 +1682,16 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"]) assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"])
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{ require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o", "id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o",
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"message": map[string]any{"role": "assistant", "content": "北京当前晴天温度25°C。"}, "message": map[string]any{"role": "assistant", "content": "北京当前晴天温度25°C。"},
"finish_reason": "stop", "finish_reason": "stop",
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -1722,7 +1735,8 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1730,7 +1744,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model", "model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{ "tools": []map[string]any{{
"type": "function", "type": "function",
@@ -1817,7 +1831,7 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1879,7 +1893,8 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n", "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@@ -1889,7 +1904,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1902,5 +1917,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
assert.Contains(t, respBody, "正常") assert.Contains(t, respBody, "正常")
} }
var _ = fmt.Sprintf var (
var _ = time.Now _ = fmt.Sprintf
_ = time.Now
)

View File

@@ -7,16 +7,17 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
) )
func init() { func init() {
@@ -97,7 +98,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
assert.NotEmpty(t, createdModel.ID) assert.NotEmpty(t, createdModel.ID)
// 3. 列出 Provider // 3. 列出 Provider
@@ -106,7 +107,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var providers []domain.Provider var providers []domain.Provider
json.Unmarshal(w.Body.Bytes(), &providers) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
assert.Len(t, providers, 1) assert.Len(t, providers, 1)
assert.Equal(t, "sk-test-key", providers[0].APIKey) assert.Equal(t, "sk-test-key", providers[0].APIKey)
@@ -116,7 +117,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var models []domain.Model var models []domain.Model
json.Unmarshal(w.Body.Bytes(), &models) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
assert.Len(t, models, 1) assert.Len(t, models, 1)
assert.Equal(t, "gpt-4", models[0].ModelName) assert.Equal(t, "gpt-4", models[0].ModelName)
@@ -163,7 +164,7 @@ func TestAnthropic_ModelCreation(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
// 验证创建成功 // 验证创建成功
w = httptest.NewRecorder() w = httptest.NewRecorder()
@@ -194,9 +195,9 @@ func TestStats_RecordingAndQuery(t *testing.T) {
// 直接通过 repository 记录统计(模拟代理请求后的统计记录) // 直接通过 repository 记录统计(模拟代理请求后的统计记录)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
statsRepo.Record("p1", "gpt-4") require.NoError(t, statsRepo.Record("p1", "gpt-4"))
// 查询统计 // 查询统计
w = httptest.NewRecorder() w = httptest.NewRecorder()
@@ -205,7 +206,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var stats []domain.UsageStats var stats []domain.UsageStats
json.Unmarshal(w.Body.Bytes(), &stats) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
assert.Len(t, stats, 1) assert.Len(t, stats, 1)
assert.Equal(t, 3, stats[0].RequestCount) assert.Equal(t, 3, stats[0].RequestCount)

View File

@@ -4,11 +4,11 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config"
) )
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。 // setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。

View File

@@ -3,9 +3,10 @@ package tests
import ( import (
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"nex/backend/internal/config"
) )
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) { func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -11,9 +11,10 @@ package mocks
import ( import (
context "context" context "context"
reflect "reflect"
conversion "nex/backend/internal/conversion" conversion "nex/backend/internal/conversion"
provider "nex/backend/internal/provider" provider "nex/backend/internal/provider"
reflect "reflect"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,9 +10,10 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,10 +10,11 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
time "time" time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

View File

@@ -10,10 +10,11 @@
package mocks package mocks
import ( import (
domain "nex/backend/internal/domain"
reflect "reflect" reflect "reflect"
time "time" time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )

5
lefthook.yml Normal file
View File

@@ -0,0 +1,5 @@
pre-commit:
commands:
backend-lint:
glob: "backend/**/*.go"
run: cd backend && go tool golangci-lint run --new-from-rev HEAD ./...

View File

@@ -169,15 +169,15 @@
- **WHEN** 应用启动 - **WHEN** 应用启动
- **THEN** SHALL 按以下顺序加载配置: - **THEN** SHALL 按以下顺序加载配置:
1. 解析 CLI 参数(获取 --config 路径) 1. 解析 CLI 参数(获取 --config 路径)
2. 初始化配置管理器 2. 初始化配置管理器
3. 设置默认值 3. 设置默认值
4. 绑定 CLI 参数 4. 绑定 CLI 参数
5. 绑定环境变量 5. 绑定环境变量
6. 读取配置文件 6. 读取配置文件(不存在时自动创建)
7. 反序列化到结构体 7. 反序列化到结构体
8. 验证配置 8. 验证配置
9. 打印配置摘要 9. 打印配置摘要
#### Scenario: 加载失败处理 #### Scenario: 加载失败处理

View File

@@ -31,16 +31,27 @@
- **THEN** SHALL 测试请求转换、响应转换、流式转换 - **THEN** SHALL 测试请求转换、响应转换、流式转换
- **THEN** SHALL 验证转换的准确性和完整性 - **THEN** SHALL 验证转换的准确性和完整性
#### Scenario: config 加载管道集成测试 #### Scenario: LoadConfigFromPath 默认值验证
- **WHEN** 运行 config 加载管道的集成测试 - **WHEN** 运行 config 加载管道的集成测试
- **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值 - **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值
- **THEN** SHALL 验证环境变量(`NEX_` 前缀)覆盖默认值
- **THEN** SHALL 验证 YAML 配置文件正确读取 - **THEN** SHALL 验证 YAML 配置文件正确读取
- **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值 - **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值
- **THEN** SHALL 验证首次启动自动创建配置文件 - **THEN** SHALL 验证首次启动自动创建配置文件
- **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致 - **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致
#### Scenario: 环境变量覆盖验证
- **WHEN** 设置 `NEX_SERVER_PORT=9000``NEX_LOG_LEVEL=debug`
- **THEN** SHALL 成功加载
- **THEN** 配置值 SHALL 反映环境变量覆盖
#### Scenario: 自动创建配置文件验证
- **WHEN** 调用 `LoadConfigFromPath` 并指向不存在的文件路径
- **THEN** SHALL 成功加载(不返回 `missing configuration for 'configPath'` 错误)
- **THEN** SHALL 返回默认配置对象
#### Scenario: handler 错误分支测试 #### Scenario: handler 错误分支测试
- **WHEN** 运行 handler 层的单元测试 - **WHEN** 运行 handler 层的单元测试