
feat: configure golangci-lint static analysis and fix existing violations

- Add backend/.golangci.yml enabling 12 linters (forbidigo, errorlint, errcheck, staticcheck, revive, gocritic, gosec, bodyclose, noctx, nilerr, goimports, gocyclo)
- Add lefthook.yml so a pre-commit hook runs lint automatically
- Fix existing violations: switch to errors.Is/As, switch to zap.Error, sort imports, address errcheck findings
- Update README with coding-convention notes
- Archive the backend-code-lint change
2026-04-24 13:01:48 +08:00
parent 4c78ab6cc8
commit 4c6b49099d
96 changed files with 1290 additions and 1348 deletions


@@ -294,6 +294,9 @@ make frontend-test-coverage # 前端覆盖率
 ## 开发
 ```bash
+# 首次克隆后安装 Git hooks
+lefthook install
+
 # 顶层便捷命令
 make dev # 启动开发环境(并行启动后端和前端)
 make build # 构建所有产物
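The `lefthook install` step above registers the hooks defined in the repository's lefthook.yml. That file is not shown in this diff, so the following is only a plausible minimal sketch; the `root`, `glob`, and `run` values are assumptions about the layout, not the actual configuration:

```yaml
pre-commit:
  parallel: true
  commands:
    backend-lint:
      root: backend/        # assumed: run from the Go module root
      glob: '*.go'          # only trigger when staged Go files change
      run: golangci-lint run
```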

backend/.golangci.yml (new file, 91 lines)

@@ -0,0 +1,91 @@
run:
  timeout: 5m
  tests: true
linters:
  disable-all: true
  enable:
    - forbidigo
    - errorlint
    - errcheck
    - staticcheck
    - revive
    - gocritic
    - gosec
    - bodyclose
    - noctx
    - nilerr
    - goimports
    - gocyclo
linters-settings:
  errcheck:
    check-blank: true
    check-type-assertions: true
    exclude-functions:
      - fmt.Fprintf
  forbidigo:
    analyze-types: true
    forbid:
      - p: '^fmt\.Print.*$'
        msg: 使用 zap logger,不要直接输出到 stdout/stderr
      - p: '^fmt\.Fprint.*$'
        msg: 使用 zap logger,不要直接输出到 stdout/stderr
      - p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
        msg: 使用 zap logger,不要使用标准库 log
      - p: '^zap\.L$'
        msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
      - p: '^zap\.S$'
        msg: 不使用 Sugar logger
  revive:
    rules:
      - name: exported
      - name: var-naming
      - name: indent-error-flow
      - name: error-strings
      - name: error-return
      - name: blank-imports
      - name: context-as-argument
      - name: unexported-return
  goimports:
    local-prefixes: nex/backend
  gocyclo:
    min-complexity: 10
issues:
  exclude-dirs:
    - tests/mocks
  exclude-generated: true
  exclude-rules:
    - path: '(_test\.go|tests/)'
      linters:
        - forbidigo
    - path: '(_test\.go|tests/)'
      linters:
        - errcheck
      source: '(^\s*_\s*=|,\s*_)'
    - path: 'tests/integration/e2e_conversion_test\.go'
      linters:
        - errcheck
    - path: '(_test\.go|tests/)'
      linters:
        - revive
      text: '^exported:'
    - path: '(_test\.go|tests/)'
      linters:
        - gosec
      text: 'G(101|401|501)'
    - path: '(_test\.go|tests/)'
      linters:
        - gocyclo
      text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
    - linters:
        - revive
      text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
    - path: 'internal/conversion/.*\.go'
      linters:
        - gocyclo
        - gocritic
    - path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
      linters:
        - gocyclo


@@ -609,6 +609,7 @@ err := v.Validate(myStruct)
 - **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
 - **字符串拼接**:使用 `strings.Join`,不手写循环拼接
-- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
+- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
 - **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
+- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
 - **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片


@@ -6,19 +6,25 @@ import (
     "fmt"
     "os/exec"
     "strings"
+
+    "go.uber.org/zap"
 )
 func showError(title, message string) {
     script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
         escapeAppleScript(message), escapeAppleScript(title))
-    exec.Command("osascript", "-e", script).Run()
+    if err := exec.Command("osascript", "-e", script).Run(); err != nil {
+        dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
+    }
 }
 func showAbout() {
     message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
     script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
         escapeAppleScript(message))
-    exec.Command("osascript", "-e", script).Run()
+    if err := exec.Command("osascript", "-e", script).Run(); err != nil {
+        dialogLogger().Warn("显示关于对话框失败", zap.Error(err))
+    }
 }
 func escapeAppleScript(s string) string {


@@ -4,7 +4,6 @@ package main
 import (
     "fmt"
-    "os"
     "os/exec"
     "sync"
 )
@@ -63,7 +62,7 @@ func showError(title, message string) {
         exec.Command("xmessage", "-center",
             fmt.Sprintf("%s: %s", title, message)).Run()
     default:
-        fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message)
+        dialogLogger().Error("无法显示错误对话框")
     }
 }
@@ -83,6 +82,6 @@ func showAbout() {
         exec.Command("xmessage", "-center",
             fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
     default:
-        fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message)
+        dialogLogger().Info("关于 Nex Gateway")
     }
 }


@@ -0,0 +1,15 @@
package main

import (
    "go.uber.org/zap"

    pkgLogger "nex/backend/pkg/logger"
)

func dialogLogger() *zap.Logger {
    if zapLogger != nil {
        return zapLogger
    }
    return pkgLogger.NewMinimal()
}


@@ -13,10 +13,7 @@ import (
     "strings"
     "time"
 
-    "github.com/getlantern/systray"
-    "github.com/gin-gonic/gin"
-    "github.com/gofrs/flock"
-    "go.uber.org/zap"
+    "nex/embedfs"
 
     "nex/backend/internal/config"
     "nex/backend/internal/conversion"
@@ -28,9 +25,13 @@ import (
     "nex/backend/internal/provider"
     "nex/backend/internal/repository"
     "nex/backend/internal/service"
-    pkgLogger "nex/backend/pkg/logger"
-    "nex/embedfs"
+
+    "github.com/getlantern/systray"
+    "github.com/gin-gonic/gin"
+    "github.com/gofrs/flock"
+    "go.uber.org/zap"
+
+    pkgLogger "nex/backend/pkg/logger"
 )
var ( var (
@@ -51,12 +52,16 @@ func main() {
         showError("Nex Gateway", "已有 Nex 实例运行")
         os.Exit(1)
     }
-    defer singleLock.Unlock()
+    defer func() {
+        if err := singleLock.Unlock(); err != nil {
+            minimalLogger.Warn("释放实例锁失败", zap.Error(err))
+        }
+    }()
     if err := checkPortAvailable(port); err != nil {
         minimalLogger.Error("端口不可用", zap.Error(err))
         showError("Nex Gateway", err.Error())
-        os.Exit(1)
+        return
     }
     cfg, err := config.LoadConfig()
@@ -75,7 +80,11 @@ func main() {
     if err != nil {
         minimalLogger.Fatal("初始化日志失败", zap.Error(err))
     }
-    defer zapLogger.Sync()
+    defer func() {
+        if err := zapLogger.Sync(); err != nil {
+            minimalLogger.Warn("同步日志失败", zap.Error(err))
+        }
+    }()
     cfg.PrintSummary(zapLogger)
@@ -144,14 +153,14 @@ func main() {
     go func() {
         zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
         if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
-            zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
+            zapLogger.Fatal("服务器启动失败", zap.Error(err))
         }
     }()
     go func() {
         time.Sleep(500 * time.Millisecond)
         if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
-            zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
+            zapLogger.Warn("无法打开浏览器", zap.Error(err))
         }
     }()
@@ -193,7 +202,7 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
 func setupStaticFiles(r *gin.Engine) {
     distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
     if err != nil {
-        zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
+        zapLogger.Fatal("无法加载前端资源", zap.Error(err))
     }
     getContentType := func(path string) string {
@@ -266,7 +275,7 @@ func setupSystray(port int) {
         icon, err = embedfs.Assets.ReadFile("assets/icon.png")
     }
     if err != nil {
-        zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
+        zapLogger.Error("无法加载托盘图标", zap.Error(err))
     }
     systray.SetIcon(icon)
     systray.SetTitle("Nex Gateway")
@@ -287,7 +296,9 @@ func setupSystray(port int) {
     for {
         select {
         case <-mOpen.ClickedCh:
-            openBrowser(fmt.Sprintf("http://localhost:%d", port))
+            if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
+                zapLogger.Warn("打开浏览器失败", zap.Error(err))
+            }
         case <-mAbout.ClickedCh:
             showAbout()
         case <-mQuit.ClickedCh:
@@ -308,7 +319,9 @@ func doShutdown() {
     if server != nil {
         ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
         defer cancel()
-        server.Shutdown(ctx)
+        if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
+            zapLogger.Warn("关闭服务器失败", zap.Error(err))
+        }
     }
     if shutdownCancel != nil {
@@ -346,8 +359,8 @@ func (s *SingletonLock) Lock() error {
     return nil
 }
-func (s *SingletonLock) Unlock() {
-    s.flock.Unlock()
+func (s *SingletonLock) Unlock() error {
+    return s.flock.Unlock()
 }
 func openBrowser(url string) error {


@@ -21,7 +21,7 @@ func TestCheckPortAvailable(t *testing.T) {
 func TestCheckPortOccupied(t *testing.T) {
     port := 19827
-    listener, err := net.Listen("tcp", ":19827")
+    listener, err := net.Listen("tcp", "127.0.0.1:19827")
     if err != nil {
         t.Fatalf("无法启动测试服务器: %v", err)
     }
@@ -47,13 +47,19 @@ func TestCheckPortOccupied(t *testing.T) {
 func TestCheckPortAvailableAfterClose(t *testing.T) {
     port := 19828
-    listener, err := net.Listen("tcp", ":19828")
+    listener, err := net.Listen("tcp", "127.0.0.1:19828")
     if err != nil {
         t.Fatalf("无法启动测试服务器: %v", err)
     }
-    server := &http.Server{}
-    go server.Serve(listener)
+    server := &http.Server{ReadHeaderTimeout: time.Second}
+    defer server.Close()
+    go func() {
+        err := server.Serve(listener)
+        if err != nil && err != http.ErrServerClosed {
+            t.Errorf("serve failed: %v", err)
+        }
+    }()
     time.Sleep(100 * time.Millisecond)


@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
     if err := lock.Lock(); err != nil {
         t.Fatalf("首次加锁应成功,但返回错误: %v", err)
     }
-    defer lock.Unlock()
+    defer func() {
+        if err := lock.Unlock(); err != nil {
+            t.Fatalf("解锁失败: %v", err)
+        }
+    }()
 }
 func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
     if err := lock1.Lock(); err != nil {
         t.Fatalf("首次加锁应成功: %v", err)
     }
-    defer lock1.Unlock()
+    defer func() {
+        if err := lock1.Unlock(); err != nil {
+            t.Fatalf("解锁失败: %v", err)
+        }
+    }()
     lock2 := NewSingletonLock(lockPath)
     err := lock2.Lock()
     if err == nil {
-        lock2.Unlock()
+        if unlockErr := lock2.Unlock(); unlockErr != nil {
+            t.Fatalf("解锁失败: %v", unlockErr)
+        }
         t.Fatal("重复加锁应失败,但返回 nil")
     }
 }
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
     if err := lock1.Lock(); err != nil {
         t.Fatalf("首次加锁应成功: %v", err)
     }
-    lock1.Unlock()
+    if err := lock1.Unlock(); err != nil {
+        t.Fatalf("解锁失败: %v", err)
+    }
     lock2 := NewSingletonLock(lockPath)
     if err := lock2.Lock(); err != nil {
         t.Fatalf("释放后重新加锁应成功: %v", err)
     }
-    lock2.Unlock()
+    if err := lock2.Unlock(); err != nil {
+        t.Fatalf("解锁失败: %v", err)
+    }
 }
 func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
     lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
-    lock.Unlock()
+    if err := lock.Unlock(); err != nil {
+        t.Fatalf("未加锁时解锁失败: %v", err)
+    }
 }


@@ -6,9 +6,9 @@ import (
     "strings"
     "testing"
 
-    "github.com/gin-gonic/gin"
     "nex/embedfs"
+
+    "github.com/gin-gonic/gin"
 )
 func TestSetupStaticFiles(t *testing.T) {


@@ -44,7 +44,11 @@ func main() {
     if err != nil {
         minimalLogger.Fatal("初始化日志失败", zap.Error(err))
     }
-    defer zapLogger.Sync()
+    defer func() {
+        if err := zapLogger.Sync(); err != nil {
+            minimalLogger.Warn("同步日志失败", zap.Error(err))
+        }
+    }()
     cfg.PrintSummary(zapLogger)


@@ -1,6 +1,7 @@
 package config
 import (
+    "errors"
     "fmt"
     "os"
     "path/filepath"
@@ -58,7 +59,10 @@ type LogConfig struct {
 // DefaultConfig returns default config values
 func DefaultConfig() *Config {
     // Use home dir for default paths
-    homeDir, _ := os.UserHomeDir()
+    homeDir, err := os.UserHomeDir()
+    if err != nil {
+        homeDir = "."
+    }
     nexDir := filepath.Join(homeDir, ".nex")
     return &Config{
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
         return "", err
     }
     configDir := filepath.Join(homeDir, ".nex")
-    if err := os.MkdirAll(configDir, 0755); err != nil {
+    if err := os.MkdirAll(configDir, 0o755); err != nil {
         return "", err
     }
     return configDir, nil
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
 // setupDefaults 设置默认配置值
 func setupDefaults(v *viper.Viper) {
-    homeDir, _ := os.UserHomeDir()
+    homeDir, err := os.UserHomeDir()
+    if err != nil {
+        homeDir = "."
+    }
     nexDir := filepath.Join(homeDir, ".nex")
     v.SetDefault("server.port", 9826)
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
     // 绑定所有 flag 到 viper
     // 注意:必须在设置默认值之后绑定
-    v.BindPFlag("server.port", flagSet.Lookup("server-port"))
-    v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
-    v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
-    v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
-    v.BindPFlag("database.path", flagSet.Lookup("database-path"))
-    v.BindPFlag("database.host", flagSet.Lookup("database-host"))
-    v.BindPFlag("database.port", flagSet.Lookup("database-port"))
-    v.BindPFlag("database.user", flagSet.Lookup("database-user"))
-    v.BindPFlag("database.password", flagSet.Lookup("database-password"))
-    v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
-    v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
-    v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
-    v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
-    v.BindPFlag("log.level", flagSet.Lookup("log-level"))
-    v.BindPFlag("log.path", flagSet.Lookup("log-path"))
-    v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
-    v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
-    v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
-    v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
+    bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
+    bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
+    bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
+    bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
+    bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
+    bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
+    bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
+    bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
+    bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
+    bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
+    bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
+    bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
+    bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
+    bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
+    bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
+    bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
+    bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
+    bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
+    bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
+}
+
+func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
+    if err := v.BindPFlag(key, flag); err != nil {
+        panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
+    }
 }
 // setupEnv 绑定环境变量
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
             return appErrors.Wrap(appErrors.ErrInternal, err)
         }
         // 配置文件不存在,创建默认配置文件
-        if err := v.SafeWriteConfig(); err != nil {
-            // 忽略写入错误(可能目录已存在等)
+        writeErr := v.SafeWriteConfig()
+        if writeErr == nil {
             return nil
         }
+        var alreadyExistsErr viper.ConfigFileAlreadyExistsError
+        if errors.As(writeErr, &alreadyExistsErr) {
+            return nil
+        }
+        return appErrors.Wrap(appErrors.ErrInternal, writeErr)
     }
     return nil
 }
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
     setupFlags(v, flagSet)
     // 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
-    flagSet.Parse(os.Args[1:])
+    if err := flagSet.Parse(os.Args[1:]); err != nil {
+        return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
+    }
     // 4. 获取配置文件路径(可能被 --config 参数覆盖)
     if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
     // Ensure directory exists
     dir := filepath.Dir(configPath)
-    if err := os.MkdirAll(dir, 0755); err != nil {
+    if err := os.MkdirAll(dir, 0o755); err != nil {
         return appErrors.Wrap(appErrors.ErrInternal, err)
     }
-    return os.WriteFile(configPath, data, 0600)
+    return os.WriteFile(configPath, data, 0o600)
 }
 // Validate validates the config


@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
     configPath := filepath.Join(dir, "config.yaml")
     data, err := yaml.Marshal(cfg)
     require.NoError(t, err)
-    err = os.WriteFile(configPath, data, 0644)
+    err = os.WriteFile(configPath, data, 0o600)
     require.NoError(t, err)
     // 加载配置


@@ -6,15 +6,15 @@ import (
 // Provider 供应商模型
 type Provider struct {
     ID        string    `gorm:"primaryKey" json:"id"`
     Name      string    `gorm:"not null" json:"name"`
     APIKey    string    `gorm:"not null" json:"api_key"`
     BaseURL   string    `gorm:"not null" json:"base_url"`
     Protocol  string    `gorm:"column:protocol;default:'openai'" json:"protocol"`
     Enabled   bool      `gorm:"default:true" json:"enabled"`
     CreatedAt time.Time `json:"created_at"`
     UpdatedAt time.Time `json:"updated_at"`
     Models    []Model   `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
 }
 // Model 模型配置:id 为 UUID 自动生成;UNIQUE(provider_id, model_name)
@@ -47,4 +47,3 @@ func (Model) TableName() string {
 func (UsageStats) TableName() string {
     return "usage_stats"
 }


@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
             Message: err.Message,
         },
     }
-    body, _ := json.Marshal(errMsg)
+    body, marshalErr := json.Marshal(errMsg)
+    if marshalErr != nil {
+        return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
+    }
     return body, statusCode
 }
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
         return "", nil, err
     }
     rewriteFunc := func(newModel string) ([]byte, error) {
-        m["model"], _ = json.Marshal(newModel)
+        encodedModel, err := json.Marshal(newModel)
+        if err != nil {
+            return nil, err
+        }
+        m["model"] = encodedModel
         return json.Marshal(m)
     }
     return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
     switch ifaceType {
     case conversion.InterfaceTypeChat:
         // Chat 响应必须有 model 字段,存在则改写,不存在则添加
-        m["model"], _ = json.Marshal(newModel)
+        encodedModel, err := json.Marshal(newModel)
+        if err != nil {
+            return nil, err
+        }
+        m["model"] = encodedModel
         return json.Marshal(m)
     default:
         return body, nil


@@ -2,6 +2,7 @@ package anthropic
 import (
     "encoding/json"
+    "errors"
     "testing"
 
     "nex/backend/internal/conversion"
@@ -52,10 +53,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
     a := NewAdapter()
     tests := []struct {
         name          string
         nativePath    string
         interfaceType conversion.InterfaceType
         expected      string
     }{
         {"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
         {"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
@@ -102,9 +103,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
     a := NewAdapter()
     tests := []struct {
         name          string
         interfaceType conversion.InterfaceType
         expected      bool
     }{
         {"聊天", conversion.InterfaceTypeChat, true},
         {"模型", conversion.InterfaceTypeModels, true},
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
     t.Run("解码嵌入请求", func(t *testing.T) {
         _, err := a.DecodeEmbeddingRequest([]byte(`{}`))
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
         provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
         _, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.True(t, errors.As(err, &convErr))
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
     t.Run("解码嵌入响应", func(t *testing.T) {
         _, err := a.DecodeEmbeddingResponse([]byte(`{}`))
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
     t.Run("编码嵌入响应", func(t *testing.T) {
         _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
 }
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
     t.Run("解码重排序请求", func(t *testing.T) {
         _, err := a.DecodeRerankRequest([]byte(`{}`))
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
         provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
         _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
     t.Run("解码重排序响应", func(t *testing.T) {
         _, err := a.DecodeRerankResponse([]byte(`{}`))
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
     t.Run("编码重排序响应", func(t *testing.T) {
         _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
         require.Error(t, err)
-        convErr, ok := err.(*conversion.ConversionError)
-        require.True(t, ok)
+        var convErr *conversion.ConversionError
+        require.ErrorAs(t, err, &convErr)
         assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
     })
 }


@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
     var canonicalMsgs []canonical.CanonicalMessage
     for _, msg := range req.Messages {
-        decoded := decodeMessage(msg)
+        decoded, err := decodeMessage(msg)
+        if err != nil {
+            return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
+        }
         canonicalMsgs = append(canonicalMsgs, decoded...)
     }
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
 }
 // decodeMessage 解码 Anthropic 消息
-func decodeMessage(msg Message) []canonical.CanonicalMessage {
+func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
     switch msg.Role {
     case "user":
-        blocks := decodeContentBlocks(msg.Content)
+        blocks, err := decodeContentBlocks(msg.Content)
+        if err != nil {
+            return nil, err
+        }
         var toolResults []canonical.ContentBlock
         var others []canonical.ContentBlock
         for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
     if len(result) == 0 {
         result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
     }
-    return result
+    return result, nil
 case "assistant":
-    blocks := decodeContentBlocks(msg.Content)
+    blocks, err := decodeContentBlocks(msg.Content)
+    if err != nil {
+        return nil, err
+    }
     if len(blocks) == 0 {
         blocks = append(blocks, canonical.NewTextBlock(""))
     }
-    return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
+    return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
 }
-return nil
+return nil, nil
 }
 // decodeContentBlocks 解码内容块列表
-func decodeContentBlocks(content any) []canonical.ContentBlock {
+func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
     switch v := content.(type) {
     case string:
-        return []canonical.ContentBlock{canonical.NewTextBlock(v)}
+        return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
     case []any:
         var blocks []canonical.ContentBlock
         for _, item := range v {
             if m, ok := item.(map[string]any); ok {
-                block := decodeSingleContentBlock(m)
+                block, err := decodeSingleContentBlock(m)
+                if err != nil {
+                    return nil, err
+                }
                 if block != nil {
                     blocks = append(blocks, *block)
                 }
             }
         }
         if len(blocks) > 0 {
-            return blocks
+            return blocks, nil
         }
-        return []canonical.ContentBlock{canonical.NewTextBlock("")}
+        return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
     case nil:
-        return []canonical.ContentBlock{canonical.NewTextBlock("")}
+        return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
     default:
-        return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
+        return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
     }
 }
 // decodeSingleContentBlock 解码单个内容块
-func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
-    t, _ := m["type"].(string)
+func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
+    t, ok := m["type"].(string)
+    if !ok {
+        return nil, nil
+    }
     switch t {
     case "text":
-        text, _ := m["text"].(string)
-        return &canonical.ContentBlock{Type: "text", Text: text}
+        text, ok := m["text"].(string)
+        if !ok {
+            text = ""
+        }
+        return &canonical.ContentBlock{Type: "text", Text: text}, nil
    case "tool_use":
-        id, _ := m["id"].(string)
-        name, _ := m["name"].(string)
-        input, _ := json.Marshal(m["input"])
-        return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
+        id, ok := m["id"].(string)
+        if !ok {
+            id = ""
+        }
+        name, ok := m["name"].(string)
+        if !ok {
+            name = ""
+        }
+        input, err := json.Marshal(m["input"])
+        if err != nil {
+            return nil, err
+        }
+        return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result": case "tool_result":
toolUseID, _ := m["tool_use_id"].(string) toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false isErr := false
if ie, ok := m["is_error"].(bool); ok { if ie, ok := m["is_error"].(bool); ok {
isErr = ie isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string: case string:
content = json.RawMessage(fmt.Sprintf("%q", cv)) content = json.RawMessage(fmt.Sprintf("%q", cv))
default: default:
content, _ = json.Marshal(cv) encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
} }
} else { } else {
content = json.RawMessage(`""`) content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID, ToolUseID: toolUseID,
Content: content, Content: content,
IsError: &isErr, IsError: &isErr,
} }, nil
case "thinking": case "thinking":
thinking, _ := m["thinking"].(string) thinking, ok := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking": case "redacted_thinking":
// 丢弃 // 丢弃
return nil return nil, nil
} }
return nil return nil, nil
} }
// decodeTools 解码工具定义 // decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "auto": case "auto":
return canonical.NewToolChoiceAuto() return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any": case "any":
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
case "tool": case "tool":
name, _ := v["name"].(string) name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
} }
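The errcheck fixes in this file all switch bare type assertions to the two-value (comma-ok) form. A minimal standalone sketch (illustrative names, not repository code) of what the `check-type-assertions` setting enforces — a bare `m["type"].(string)` panics when the key is absent or holds a non-string value, while the comma-ok form reports that through `ok`:

```go
package main

// decodeType is a hypothetical helper showing the comma-ok assertion
// form: the second return value is false instead of panicking when the
// map value is missing or has a different dynamic type.
func decodeType(m map[string]any) (string, bool) {
	t, ok := m["type"].(string)
	return t, ok
}

func main() {
	if t, ok := decodeType(map[string]any{"type": "text"}); !ok || t != "text" {
		panic("expected ok for a string value")
	}
	if _, ok := decodeType(map[string]any{"type": 42}); ok {
		panic("expected !ok for a non-string value")
	}
}
```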


@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
        result = append(result, m)
    case "tool_result":
        m := map[string]any{
            "type":        "tool_result",
            "tool_use_id": b.ToolUseID,
        }
        if b.Content != nil {
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
    }
    result := map[string]any{
        "id":            resp.ID,
        "type":          "message",
        "role":          "assistant",
        "model":         resp.Model,
        "content":       blocks,
        "stop_reason":   sr,
        "stop_sequence": nil,
        "usage":         usage,


@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
    assert.Equal(t, true, result["stream"])
    assert.Equal(t, float64(1024), result["max_tokens"])
-    msgs := result["messages"].([]any)
+    msgs, ok := result["messages"].([]any)
+    require.True(t, ok)
    assert.Len(t, msgs, 1)
 }
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    msgs := result["messages"].([]any)
+    msgs, ok := result["messages"].([]any)
+    require.True(t, ok)
    // tool 消息应被合并到相邻 user 消息
    foundToolResult := false
    for _, m := range msgs {
-        msgMap := m.(map[string]any)
+        msgMap, ok := m.(map[string]any)
+        require.True(t, ok)
        if msgMap["role"] == "user" {
            content, ok := msgMap["content"].([]any)
            if ok {
                for _, c := range content {
-                    block := c.(map[string]any)
+                    block, ok := c.(map[string]any)
+                    require.True(t, ok)
                    if block["type"] == "tool_result" {
                        foundToolResult = true
                    }
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    msgs := result["messages"].([]any)
-    firstMsg := msgs[0].(map[string]any)
+    msgs, ok := result["messages"].([]any)
+    require.True(t, ok)
+    firstMsg, ok := msgs[0].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "user", firstMsg["role"])
 }
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
    assert.Equal(t, "assistant", result["role"])
    assert.Equal(t, "end_turn", result["stop_reason"])
-    content := result["content"].([]any)
+    content, ok := result["content"].([]any)
+    require.True(t, ok)
    assert.Len(t, content, 1)
-    block := content[0].(map[string]any)
+    block, ok := content[0].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "text", block["type"])
    assert.Equal(t, "你好", block["text"])
 }
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    data := result["data"].([]any)
+    data, ok := result["data"].([]any)
+    require.True(t, ok)
    assert.Len(t, data, 1)
-    model := data[0].(map[string]any)
+    model, ok := data[0].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "claude-3-opus", model["id"])
    // created 应为 RFC3339 格式
    createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    msgs := result["messages"].([]any)
+    msgs, ok := result["messages"].([]any)
+    require.True(t, ok)
    assert.Len(t, msgs, 1)
-    userMsg := msgs[0].(map[string]any)
+    userMsg, ok := msgs[0].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "user", userMsg["role"])
-    content := userMsg["content"].([]any)
+    content, ok := userMsg["content"].([]any)
+    require.True(t, ok)
    assert.Len(t, content, 2)
 }
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    usage := result["usage"].(map[string]any)
+    usage, ok := result["usage"].(map[string]any)
+    require.True(t, ok)
    _, hasReasoning := usage["reasoning_tokens"]
    assert.False(t, hasReasoning)
 }
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    content := result["content"].([]any)
+    content, ok := result["content"].([]any)
+    require.True(t, ok)
    assert.Len(t, content, 1)
-    block := content[0].(map[string]any)
+    block, ok := content[0].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "tool_use", block["type"])
    assert.Equal(t, "tool_1", block["id"])
    assert.Equal(t, "search", block["name"])


@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
 func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
    data := rawChunk
    if len(d.utf8Remainder) > 0 {
-        data = append(d.utf8Remainder, rawChunk...)
+        data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
        d.utf8Remainder = nil
    }
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
    for _, line := range strings.Split(text, "\n") {
        line = strings.TrimRight(line, "\r")
-        if strings.HasPrefix(line, "event: ") {
+        switch {
+        case strings.HasPrefix(line, "event: "):
            eventType = strings.TrimPrefix(line, "event: ")
-        } else if strings.HasPrefix(line, "data: ") {
+        case strings.HasPrefix(line, "data: "):
            eventData = strings.TrimPrefix(line, "data: ")
            if eventType != "" && eventData != "" {
                chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
            }
            eventType = ""
            eventData = ""
-        } else if line == "" {
-            // SSE 事件分隔符
+        case line == "":
+            continue
        }
    }
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
 // processContentBlockStart 处理内容块开始事件
 func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
    var raw struct {
        Index        int `json:"index"`
        ContentBlock struct {
            Type string `json:"type"`
            Text string `json:"text"`
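The `ProcessChunk` change above replaces `append(d.utf8Remainder, rawChunk...)` with a copy-then-append. A standalone sketch of the idiom (`joinChunks` is a hypothetical helper, not repository code), assuming the motivation is the usual one: `append` may write into the remainder's backing array when it has spare capacity, so copying into a fresh slice guarantees carried-over bytes are never shared with a later chunk:

```go
package main

// joinChunks prepends a leftover byte sequence to the next chunk
// without aliasing either input's backing array.
func joinChunks(remainder, chunk []byte) []byte {
	return append(append([]byte{}, remainder...), chunk...)
}

func main() {
	rem := []byte{0xE4, 0xBD} // first two bytes of the 3-byte UTF-8 rune 你
	out := joinChunks(rem, []byte{0xA0})
	if string(out) != "你" {
		panic("unexpected join result")
	}
	if len(rem) != 2 || rem[0] != 0xE4 {
		panic("remainder was mutated")
	}
}
```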


@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
        checkValue string
    }{
        {
            name:       "text_delta",
            deltaType:  "text_delta",
            deltaData:  map[string]any{"type": "text_delta", "text": "你好"},
            checkField: "text",
            checkValue: "你好",
        },
        {
            name:       "input_json_delta",
            deltaType:  "input_json_delta",
            deltaData:  map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
            checkField: "partial_json",
            checkValue: "{\"key\":",
        },
        {
            name:       "thinking_delta",
            deltaType:  "thinking_delta",
            deltaData:  map[string]any{"type": "thinking_delta", "thinking": "思考中"},
            checkField: "thinking",
            checkValue: "思考中",
        },
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
            payload := map[string]any{
                "type":  "content_block_delta",
                "index": 0,
                "delta": tt.deltaData,
            }
            raw := makeAnthropicEvent("content_block_delta", payload)
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
        "type":  "content_block_start",
        "index": 3,
        "content_block": map[string]any{
            "type":        "web_search_tool_result",
            "tool_use_id": "search_1",
        },
    }
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
        "type":  "content_block_delta",
        "index": 0,
        "delta": map[string]any{
            "type":     "citations_delta",
            "citation": map[string]any{"title": "ref1"},
        },
    }
    raw := makeAnthropicEvent("content_block_delta", payload)
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
        },
    }
    deltaPayload1 := map[string]any{
        "type":  "message_delta",
        "delta": map[string]any{"stop_reason": "end_turn"},
        "usage": map[string]any{"output_tokens": 25},
    }
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
    assert.Equal(t, 25, events[0].Usage.OutputTokens)
    deltaPayload2 := map[string]any{
        "type":  "message_delta",
        "delta": map[string]any{"stop_reason": "end_turn"},
        "usage": map[string]any{"output_tokens": 30},
    }


@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
            break
        }
    }
-    cb := payload["content_block"].(map[string]any)
+    cb, ok := payload["content_block"].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "text", cb["type"])
 }
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
            break
        }
    }
-    cb := payload["content_block"].(map[string]any)
+    cb, ok := payload["content_block"].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "tool_use", cb["type"])
    assert.Equal(t, "toolu_1", cb["id"])
    assert.Equal(t, "search", cb["name"])
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
            break
        }
    }
-    cb := payload["content_block"].(map[string]any)
+    cb, ok := payload["content_block"].(map[string]any)
+    require.True(t, ok)
    assert.Equal(t, "thinking", cb["type"])
 }
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
            break
        }
    }
-    delta := payload["delta"].(map[string]any)
+    delta, okd := payload["delta"].(map[string]any)
+    require.True(t, okd)
    assert.Equal(t, "end_turn", delta["stop_reason"])
 }
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
            break
        }
    }
-    u := payload["usage"].(map[string]any)
+    u, oku := payload["usage"].(map[string]any)
+    require.True(t, oku)
    assert.Equal(t, float64(88), u["output_tokens"])
 }


@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
 }
 func TestDecodeContentBlocks_Nil(t *testing.T) {
-    blocks := decodeContentBlocks(nil)
+    blocks, err := decodeContentBlocks(nil)
+    require.NoError(t, err)
    assert.Len(t, blocks, 1)
    assert.Equal(t, "", blocks[0].Text)
 }
 func TestDecodeContentBlocks_String(t *testing.T) {
-    blocks := decodeContentBlocks("hello")
+    blocks, err := decodeContentBlocks("hello")
+    require.NoError(t, err)
    assert.Len(t, blocks, 1)
    assert.Equal(t, "hello", blocks[0].Text)
 }
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := encodeToolChoice(tt.choice)
-            assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
-            assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
+            r, ok := result.(map[string]any)
+            require.True(t, ok)
+            assert.Equal(t, tt.want["type"], r["type"])
+            assert.Equal(t, tt.want["name"], r["name"])
        })
    }
 }
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    tools := result["tools"].([]any)
+    tools, okt := result["tools"].([]any)
+    require.True(t, okt)
    assert.Len(t, tools, 1)
-    tool := tools[0].(map[string]any)
+    tool, okt2 := tools[0].(map[string]any)
+    require.True(t, okt2)
    assert.Equal(t, "search", tool["name"])
    assert.Equal(t, "Search things", tool["description"])
-    tc := result["tool_choice"].(map[string]any)
+    tc, oktc := result["tool_choice"].(map[string]any)
+    require.True(t, oktc)
    assert.Equal(t, "auto", tc["type"])
 }
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
        Content:    []canonical.ContentBlock{canonical.NewTextBlock("ok")},
        StopReason: &sr,
        Usage: canonical.CanonicalUsage{
            InputTokens:         100,
            OutputTokens:        50,
            CacheReadTokens:     &cacheRead,
            CacheCreationTokens: &cacheCreation,
        },
    }
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
    var result map[string]any
    require.NoError(t, json.Unmarshal(body, &result))
-    usage := result["usage"].(map[string]any)
+    usage, oku := result["usage"].(map[string]any)
+    require.True(t, oku)
    assert.Equal(t, float64(100), usage["input_tokens"])
    assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
    assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])


@@ -6,22 +6,22 @@ import (
 // MessagesRequest Anthropic Messages 请求
 type MessagesRequest struct {
    Model                  string           `json:"model"`
    Messages               []Message        `json:"messages"`
    System                 any              `json:"system,omitempty"`
    MaxTokens              int              `json:"max_tokens"`
    Temperature            *float64         `json:"temperature,omitempty"`
    TopP                   *float64         `json:"top_p,omitempty"`
    TopK                   *int             `json:"top_k,omitempty"`
    StopSequences          []string         `json:"stop_sequences,omitempty"`
    Stream                 bool             `json:"stream,omitempty"`
    Tools                  []Tool           `json:"tools,omitempty"`
    ToolChoice             any              `json:"tool_choice,omitempty"`
    Metadata               *RequestMetadata `json:"metadata,omitempty"`
    Thinking               *ThinkingConfig  `json:"thinking,omitempty"`
    OutputConfig           *OutputConfig    `json:"output_config,omitempty"`
    DisableParallelToolUse *bool            `json:"disable_parallel_tool_use,omitempty"`
    Container              any              `json:"container,omitempty"`
 }
 // RequestMetadata 请求元数据
@@ -122,8 +122,8 @@ type ContentBlock struct {
 // ResponseUsage 响应用量
 type ResponseUsage struct {
    InputTokens              int  `json:"input_tokens"`
    OutputTokens             int  `json:"output_tokens"`
    CacheReadInputTokens     *int `json:"cache_read_input_tokens,omitempty"`
    CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
 }


@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
 // EmbeddingData 嵌入数据项
 type EmbeddingData struct {
    Index     int `json:"index"`
    Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
 }
 // EmbeddingUsage 嵌入用量


@@ -18,17 +18,17 @@ const (
 type DeltaType string
 const (
    DeltaTypeText      DeltaType = "text_delta"
    DeltaTypeInputJSON DeltaType = "input_json_delta"
    DeltaTypeThinking  DeltaType = "thinking_delta"
 )
 // StreamDelta 流式增量联合体
 type StreamDelta struct {
    Type        string `json:"type"`
    Text        string `json:"text,omitempty"`
    PartialJSON string `json:"partial_json,omitempty"`
    Thinking    string `json:"thinking,omitempty"`
 }
 // StreamContentBlock 流式内容块联合体
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
    Message *StreamMessage `json:"message,omitempty"`
    // ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
    Index        *int                `json:"index,omitempty"`
    ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
    Delta        *StreamDelta        `json:"delta,omitempty"`
    // MessageDeltaEvent
    StopReason *StopReason     `json:"stop_reason,omitempty"`
    Usage      *CanonicalUsage `json:"usage,omitempty"`
    // ErrorEvent


@@ -40,8 +40,8 @@ type ContentBlock struct {
    Text string `json:"text,omitempty"`
    // ToolUseBlock
    ID    string          `json:"id,omitempty"`
    Name  string          `json:"name,omitempty"`
    Input json.RawMessage `json:"input,omitempty"`
    // ToolResultBlock
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
 // OutputFormat 输出格式联合体
 type OutputFormat struct {
    Type   string          `json:"type"`
    Name   string          `json:"name,omitempty"`
    Schema json.RawMessage `json:"schema,omitempty"`
    Strict *bool           `json:"strict,omitempty"`
 }
 // CanonicalRequest 规范请求
 type CanonicalRequest struct {
    Model           string             `json:"model"`
    System          any                `json:"system,omitempty"` // nil, string, or []SystemBlock
    Messages        []CanonicalMessage `json:"messages"`
    Tools           []CanonicalTool    `json:"tools,omitempty"`
    ToolChoice      *ToolChoice        `json:"tool_choice,omitempty"`
    Parameters      RequestParameters  `json:"parameters"`
    Thinking        *ThinkingConfig    `json:"thinking,omitempty"`
    Stream          bool               `json:"stream"`
    UserID          string             `json:"user_id,omitempty"`
    OutputFormat    *OutputFormat      `json:"output_format,omitempty"`
    ParallelToolUse *bool              `json:"parallel_tool_use,omitempty"`
 }
 // CanonicalUsage 规范用量
 type CanonicalUsage struct {
    InputTokens         int  `json:"input_tokens"`
    OutputTokens        int  `json:"output_tokens"`
    CacheReadTokens     *int `json:"cache_read_tokens,omitempty"`
    CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
    ReasoningTokens     *int `json:"reasoning_tokens,omitempty"`
 }
 // CanonicalResponse 规范响应
 type CanonicalResponse struct {
    ID         string         `json:"id"`
    Model      string         `json:"model"`
    Content    []ContentBlock `json:"content"`
    StopReason *StopReason    `json:"stop_reason,omitempty"`
    Usage      CanonicalUsage `json:"usage"`
 }
 // GetSystemString 获取系统消息字符串


@@ -10,9 +10,9 @@ import (
 func TestGetSystemString(t *testing.T) {
    tests := []struct {
        name   string
        system any
        want   string
    }{
        {"string", "hello", "hello"},
        {"nil", nil, ""},
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
 func TestCanonicalResponse_RoundTrip(t *testing.T) {
    sr := StopReasonEndTurn
    resp := &CanonicalResponse{
        ID:         "resp-1",
        Model:      "gpt-4",
        Content:    []ContentBlock{NewTextBlock("hello")},
        StopReason: &sr,
        Usage:      CanonicalUsage{InputTokens: 10, OutputTokens: 5},
    }
    data, err := json.Marshal(resp)


@@ -114,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
interfaceType := clientAdapter.DetectInterfaceType(nativePath) interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerHeaders := providerAdapter.BuildHeaders(provider) providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil { if err != nil {
@@ -122,7 +122,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl, URL: provider.BaseURL + providerURL,
Method: spec.Method, Method: spec.Method,
Headers: providerHeaders, Headers: providerHeaders,
Body: providerBody, Body: providerBody,
@@ -134,24 +134,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段 // Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 { if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return &spec, nil rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
} if rewriteErr != nil {
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(err), zap.Error(rewriteErr),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
return &spec, nil } else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
} }
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
} }
return &spec, nil return &spec, nil
} }
@@ -182,11 +179,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
 	if e.IsPassthrough(clientProtocol, providerProtocol) {
 		// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
 		if modelOverride != "" {
-			adapter, err := e.registry.Get(clientProtocol)
-			if err != nil {
-				return NewPassthroughStreamConverter(), nil
-			}
-			return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
+			adapter, getErr := e.registry.Get(clientProtocol)
+			if getErr == nil {
+				return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
+			}
 		}
 		return NewPassthroughStreamConverter(), nil
 	}
@@ -201,9 +197,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
 	}
 	ctx := ConversionContext{
 		ConversionID:  uuid.New().String(),
 		InterfaceType: InterfaceTypeChat,
 		Timestamp:     time.Now(),
 	}
 	return NewCanonicalStreamConverterWithMiddleware(
@@ -306,7 +302,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
 func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
 	models, err := providerAdapter.DecodeModelsResponse(body)
 	if err != nil {
-		e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
+		e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
 		return body, nil
 	}
 	encoded, err := clientAdapter.EncodeModelsResponse(models)
@@ -320,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
 func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
 	info, err := providerAdapter.DecodeModelInfoResponse(body)
 	if err != nil {
-		e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
+		e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
 		return body, nil
 	}
 	encoded, err := clientAdapter.EncodeModelInfoResponse(info)
 	if err != nil {
-		e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
+		e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
 		return body, nil
 	}
 	return encoded, nil
@@ -334,7 +330,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
 func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
 	req, err := clientAdapter.DecodeEmbeddingRequest(body)
 	if err != nil {
-		e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
+		e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
 		return body, nil
 	}
 	return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -343,7 +339,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
 func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
 	resp, err := providerAdapter.DecodeEmbeddingResponse(body)
 	if err != nil {
-		e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
+		e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
 		return body, nil
 	}
 	if modelOverride != "" {
@@ -355,21 +351,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
 func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
 	req, err := clientAdapter.DecodeRerankRequest(body)
 	if err != nil {
-		e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
+		e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
 		return body, nil
 	}
 	return providerAdapter.EncodeRerankRequest(req, provider)
 }

 func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
-	resp, err := providerAdapter.DecodeRerankResponse(body)
-	if err != nil {
-		return body, nil
+	resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
+	if decodeErr == nil {
+		if modelOverride != "" {
+			resp.Model = modelOverride
+		}
+		return clientAdapter.EncodeRerankResponse(resp)
 	}
-	if modelOverride != "" {
-		resp.Model = modelOverride
-	}
-	return clientAdapter.EncodeRerankResponse(resp)
+	return body, nil
 }

 // DetectInterfaceType 检测接口类型
@@ -391,8 +388,12 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
 				"type": "internal_error",
 			},
 		}
-		body, _ := json.Marshal(fallback)
-		return body, 500, nil
+		body, marshalErr := json.Marshal(fallback)
+		if marshalErr == nil {
+			return body, 500, nil
+		}
+		return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
 	}
 	body, statusCode := adapter.EncodeError(err)
 	return body, statusCode, nil
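The `errcheck` fixes in this file replace `body, _ := json.Marshal(...)` with an explicit error branch that falls back to a hand-written JSON literal, so callers always receive valid JSON even when marshalling fails. A standalone sketch of the same pattern (function and payload names here are illustrative, not taken from the repo):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// encodeErrorBody marshals an error payload; on marshal failure it falls
// back to a constant JSON literal so the caller never gets an empty body.
func encodeErrorBody(payload map[string]any) []byte {
	body, err := json.Marshal(payload)
	if err != nil {
		// json.Marshal can fail, e.g. for unsupported types such as channels.
		return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`)
	}
	return body
}

func main() {
	ok := encodeErrorBody(map[string]any{"error": map[string]any{"message": "boom"}})
	bad := encodeErrorBody(map[string]any{"ch": make(chan int)}) // not marshalable
	fmt.Println(string(ok))
	fmt.Println(string(bad))
}
```

The fallback literal is written by hand precisely so it cannot itself fail to marshal.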


@@ -38,8 +38,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
 	}
 }

 func (m *mockProtocolAdapter) ProtocolName() string      { return m.protocolName }
 func (m *mockProtocolAdapter) ProtocolVersion() string   { return "1.0" }
 func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }

 func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
@@ -190,14 +190,16 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
 // noopStreamDecoder 空流式解码器
 type noopStreamDecoder struct{}

-func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
+func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
+	return nil
+}
 func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }

 // noopStreamEncoder 空流式编码器
 type noopStreamEncoder struct{}

 func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
 func (e *noopStreamEncoder) Flush() [][]byte { return nil }

 // ============ 测试用例 ============
@@ -615,6 +617,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
 	}
 	return nil
 }
+
 func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
 	if d.flushFn != nil {
 		return d.flushFn()
@@ -634,6 +637,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
 	}
 	return nil
 }
+
 func (e *engineTestStreamEncoder) Flush() [][]byte {
 	if e.flushFn != nil {
 		return e.flushFn()


@@ -6,17 +6,17 @@ import "fmt"
 type ErrorCode string

 const (
 	ErrorCodeInvalidInput          ErrorCode = "INVALID_INPUT"
 	ErrorCodeMissingRequiredField  ErrorCode = "MISSING_REQUIRED_FIELD"
 	ErrorCodeIncompatibleFeature   ErrorCode = "INCOMPATIBLE_FEATURE"
 	ErrorCodeFieldMappingFailure   ErrorCode = "FIELD_MAPPING_FAILURE"
 	ErrorCodeToolCallParseError    ErrorCode = "TOOL_CALL_PARSE_ERROR"
 	ErrorCodeJSONParseError        ErrorCode = "JSON_PARSE_ERROR"
 	ErrorCodeStreamStateError      ErrorCode = "STREAM_STATE_ERROR"
 	ErrorCodeUTF8DecodeError       ErrorCode = "UTF8_DECODE_ERROR"
 	ErrorCodeProtocolConstraint    ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
 	ErrorCodeEncodingFailure       ErrorCode = "ENCODING_FAILURE"
 	ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
 )

 // ConversionError 协议转换错误


@@ -4,10 +4,10 @@ package conversion
 type InterfaceType string

 const (
 	InterfaceTypeChat        InterfaceType = "CHAT"
 	InterfaceTypeModels      InterfaceType = "MODELS"
 	InterfaceTypeModelInfo   InterfaceType = "MODEL_INFO"
 	InterfaceTypeEmbeddings  InterfaceType = "EMBEDDINGS"
 	InterfaceTypeRerank      InterfaceType = "RERANK"
 	InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
 )


@@ -138,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
 			Code: string(err.Code),
 		},
 	}
-	body, _ := json.Marshal(errMsg)
+	body, marshalErr := json.Marshal(errMsg)
+	if marshalErr != nil {
+		return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
+	}
 	return body, statusCode
 }
@@ -248,7 +251,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
 		return "", nil, err
 	}
 	rewriteFunc := func(newModel string) ([]byte, error) {
-		m["model"], _ = json.Marshal(newModel)
+		encodedModel, err := json.Marshal(newModel)
+		if err != nil {
+			return nil, err
+		}
+		m["model"] = encodedModel
 		return json.Marshal(m)
 	}
 	return current, rewriteFunc, nil
@@ -282,12 +289,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
 	switch ifaceType {
 	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
 		// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
-		m["model"], _ = json.Marshal(newModel)
+		encodedModel, err := json.Marshal(newModel)
+		if err != nil {
+			return nil, err
+		}
+		m["model"] = encodedModel
 		return json.Marshal(m)
 	case conversion.InterfaceTypeRerank:
 		// Rerank 响应:存在 model 字段则改写,不存在则不添加
 		if _, exists := m["model"]; exists {
-			m["model"], _ = json.Marshal(newModel)
+			encodedModel, err := json.Marshal(newModel)
+			if err != nil {
+				return nil, err
+			}
+			m["model"] = encodedModel
 		}
 		return json.Marshal(m)
 	default:


@@ -48,10 +48,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
 	a := NewAdapter()
 	tests := []struct {
 		name          string
 		nativePath    string
 		interfaceType conversion.InterfaceType
 		expected      string
 	}{
 		{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
 		{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
@@ -92,9 +92,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
 	a := NewAdapter()
 	tests := []struct {
 		name          string
 		interfaceType conversion.InterfaceType
 		expected      bool
 	}{
 		{"聊天", conversion.InterfaceTypeChat, true},
 		{"模型", conversion.InterfaceTypeModels, true},


@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
 	var blocks []canonical.ContentBlock
 	for _, item := range v {
 		if m, ok := item.(map[string]any); ok {
-			t, _ := m["type"].(string)
+			t, ok := m["type"].(string)
+			if !ok {
+				continue
+			}
 			switch t {
 			case "text":
-				text, _ := m["text"].(string)
+				text, ok := m["text"].(string)
+				if !ok {
+					text = ""
+				}
 				blocks = append(blocks, canonical.NewTextBlock(text))
 			case "image_url":
 				blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
 // contentPart 内容部分
 type contentPart struct {
 	Type    string
 	Text    string
 	Refusal string
 }

 // decodeContentParts 解码内容部分
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
 	var result []contentPart
 	for _, item := range parts {
 		if m, ok := item.(map[string]any); ok {
-			t, _ := m["type"].(string)
+			t, ok := m["type"].(string)
+			if !ok {
+				continue
+			}
 			switch t {
 			case "text":
-				text, _ := m["text"].(string)
+				text, ok := m["text"].(string)
+				if !ok {
+					text = ""
+				}
 				result = append(result, contentPart{Type: "text", Text: text})
 			case "refusal":
-				refusal, _ := m["refusal"].(string)
+				refusal, ok := m["refusal"].(string)
+				if !ok {
+					refusal = ""
+				}
 				result = append(result, contentPart{Type: "refusal", Refusal: refusal})
 			}
 		}
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
 			return canonical.NewToolChoiceAny()
 		}
 	case map[string]any:
-		t, _ := v["type"].(string)
+		t, ok := v["type"].(string)
+		if !ok {
+			return nil
+		}
 		switch t {
 		case "function":
 			if fn, ok := v["function"].(map[string]any); ok {
-				name, _ := fn["name"].(string)
+				name, ok := fn["name"].(string)
+				if !ok {
+					name = ""
+				}
 				return canonical.NewToolChoiceNamed(name)
 			}
 		case "custom":
 			if custom, ok := v["custom"].(map[string]any); ok {
-				name, _ := custom["name"].(string)
+				name, ok := custom["name"].(string)
+				if !ok {
+					name = ""
+				}
 				return canonical.NewToolChoiceNamed(name)
 			}
 		case "allowed_tools":
 			if at, ok := v["allowed_tools"].(map[string]any); ok {
-				mode, _ := at["mode"].(string)
+				mode, ok := at["mode"].(string)
+				if !ok {
+					mode = ""
+				}
 				if mode == "required" {
 					return canonical.NewToolChoiceAny()
 				}
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
 	case map[string]any:
 		if name, ok := v["name"].(string); ok {
 			req.ToolChoice = map[string]any{
 				"type":     "function",
 				"function": map[string]any{"name": name},
 			}
 		}
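The decode fixes above all follow one pattern: replace a blind `x, _ := v.(T)` type assertion with a comma-ok check plus an explicit default or `continue`, so missing or mistyped JSON fields are handled deliberately rather than silently zeroed. A self-contained sketch of that pattern; the `extractText` helper is illustrative, not a function from the repo:

```go
package main

import "fmt"

// extractText pulls the "text" field out of loosely-typed JSON parts,
// skipping entries whose "type" is missing and defaulting absent text to "".
func extractText(parts []any) []string {
	var out []string
	for _, item := range parts {
		m, ok := item.(map[string]any)
		if !ok {
			continue // not an object at all
		}
		t, ok := m["type"].(string)
		if !ok {
			continue // explicit skip instead of a blind `t, _ :=` assertion
		}
		if t == "text" {
			text, ok := m["text"].(string)
			if !ok {
				text = "" // explicit default instead of a silent zero value
			}
			out = append(out, text)
		}
	}
	return out
}

func main() {
	parts := []any{
		map[string]any{"type": "text", "text": "hello"},
		map[string]any{"type": "text"},             // missing text -> ""
		map[string]any{"text": "no type, skipped"}, // missing type -> skipped
		"not an object",                            // skipped
	}
	fmt.Printf("%q\n", extractText(parts))
}
```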


@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
 		"object": "list",
 		"data":   data,
 		"model":  resp.Model,
 		"usage":  resp.Usage,
 	})
 }


@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))

-	msgs := result["messages"].([]any)
+	msgs, ok := result["messages"].([]any)
+	require.True(t, ok)
 	assert.Len(t, msgs, 2)
-	firstMsg := msgs[0].(map[string]any)
+	firstMsg, ok := msgs[0].(map[string]any)
+	require.True(t, ok)
 	assert.Equal(t, "system", firstMsg["role"])
 	assert.Equal(t, "你是助手", firstMsg["content"])
 }
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))

-	msgs := result["messages"].([]any)
-	assistantMsg := msgs[0].(map[string]any)
+	msgs, ok := result["messages"].([]any)
+	require.True(t, ok)
+	assistantMsg, ok := msgs[0].(map[string]any)
+	require.True(t, ok)
 	toolCalls, ok := assistantMsg["tool_calls"].([]any)
 	require.True(t, ok)
 	assert.Len(t, toolCalls, 1)
-	tc := toolCalls[0].(map[string]any)
+	tc, ok := toolCalls[0].(map[string]any)
+	require.True(t, ok)
 	assert.Equal(t, "call_1", tc["id"])
 }
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
 func TestEncodeResponse_Basic(t *testing.T) {
 	sr := canonical.StopReasonEndTurn
 	resp := &canonical.CanonicalResponse{
 		ID:         "resp-1",
 		Model:      "gpt-4",
 		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
 		StopReason: &sr,
 		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
 	}

 	body, err := encodeResponse(resp)
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
 	assert.Equal(t, "resp-1", result["id"])
 	assert.Equal(t, "chat.completion", result["object"])

-	choices := result["choices"].([]any)
-	choice := choices[0].(map[string]any)
-	msg := choice["message"].(map[string]any)
+	choices, ok := result["choices"].([]any)
+	require.True(t, ok)
+	choice, ok := choices[0].(map[string]any)
+	require.True(t, ok)
+	msg, ok := choice["message"].(map[string]any)
+	require.True(t, ok)
 	assert.Equal(t, "你好", msg["content"])
 	assert.Equal(t, "stop", choice["finish_reason"])
 }
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
 	sr := canonical.StopReasonToolUse
 	input := json.RawMessage(`{"q":"test"}`)
 	resp := &canonical.CanonicalResponse{
 		ID:         "resp-2",
 		Model:      "gpt-4",
 		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
 		StopReason: &sr,
 	}
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))

-	choices := result["choices"].([]any)
-	msg := choices[0].(map[string]any)["message"].(map[string]any)
+	choices, okc := result["choices"].([]any)
+	require.True(t, okc)
+	msgMap, okm := choices[0].(map[string]any)
+	require.True(t, okm)
+	msg, okmsg := msgMap["message"].(map[string]any)
+	require.True(t, okmsg)
 	tcs, ok := msg["tool_calls"].([]any)
 	require.True(t, ok)
 	assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))
 	assert.Equal(t, "list", result["object"])
-	data := result["data"].([]any)
+	data, okd := result["data"].([]any)
+	require.True(t, okd)
 	assert.Len(t, data, 2)
 }
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))

-	choices := result["choices"].([]any)
-	msg := choices[0].(map[string]any)["message"].(map[string]any)
+	choices, okch := result["choices"].([]any)
+	require.True(t, okch)
+	msgMap, okmm := choices[0].(map[string]any)
+	require.True(t, okmm)
+	msg, okmsg := msgMap["message"].(map[string]any)
+	require.True(t, okmsg)
 	assert.Equal(t, "回答", msg["content"])
 	assert.Equal(t, "思考过程", msg["reasoning_content"])
 }


@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
 	d := NewStreamDecoder()
 	chunk := map[string]any{
 		"id":     "chatcmpl-1",
 		"object": "chat.completion.chunk",
 		"model":  "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
 	idx := 0
 	chunk := map[string]any{
 		"id":    "chatcmpl-1",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
 	d := NewStreamDecoder()
 	chunk := map[string]any{
 		"id":    "chatcmpl-1",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
 	d := NewStreamDecoder()
 	chunk := map[string]any{
 		"id":    "chatcmpl-1",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
 	// 先发送一个文本 chunk
 	chunk := map[string]any{
 		"id":    "chatcmpl-1",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
 	// 连续两个 refusal delta chunk
 	for _, text := range []string{"拒绝", "原因"} {
 		chunk := map[string]any{
 			"id":    "chatcmpl-1",
 			"model": "gpt-4",
 			"choices": []any{
 				map[string]any{
 					"index": 0,
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
 	idx0 := 0
 	chunk1 := map[string]any{
 		"id":    "chatcmpl-mt",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
 	idx1 := 1
 	chunk2 := map[string]any{
 		"id":    "chatcmpl-mt",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
 	d := NewStreamDecoder()
 	chunk1 := map[string]any{
 		"id":    "chatcmpl-multi",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
 		},
 	}
 	chunk2 := map[string]any{
 		"id":    "chatcmpl-multi",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
 	d := NewStreamDecoder()
 	chunk := map[string]any{
 		"id":    "chatcmpl-utf8",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
 	idx := 0
 	chunk1 := map[string]any{
 		"id":    "chatcmpl-tc",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
 		},
 	}
 	chunk2 := map[string]any{
 		"id":    "chatcmpl-tc",
 		"model": "gpt-4",
 		"choices": []any{
 			map[string]any{
 				"index": 0,


@@ -10,9 +10,9 @@ import (
 // StreamEncoder OpenAI 流式编码器
 type StreamEncoder struct {
 	bufferedStart     *canonical.CanonicalStreamEvent
 	toolCallIndexMap  map[string]int
 	nextToolCallIndex int
 }

 // NewStreamEncoder 创建 OpenAI 流式编码器
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
 func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
 	chunk := map[string]any{
 		"choices": []map[string]any{{
 			"index": 0,
 			"delta": delta,
 		}},
 	}
 	return e.marshalChunk(chunk)


@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
 	data := strings.TrimPrefix(s, "data: ")
 	data = strings.TrimRight(data, "\n")
 	require.NoError(t, json.Unmarshal([]byte(data), &payload))
-	choices := payload["choices"].([]any)
-	delta := choices[0].(map[string]any)["delta"].(map[string]any)
+	choices, okch := payload["choices"].([]any)
+	require.True(t, okch)
+	msgMap, okmm := choices[0].(map[string]any)
+	require.True(t, okmm)
+	delta, okd := msgMap["delta"].(map[string]any)
+	require.True(t, okd)
 	assert.Equal(t, "assistant", delta["role"])
 }
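The test above strips the SSE `data: ` prefix and trailing newlines before unmarshalling a streamed chunk. A minimal round-trip sketch of that Server-Sent Events framing; the `frameSSE`/`parseSSE` helper names are illustrative, not functions from the repo:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// frameSSE wraps a JSON payload in SSE framing: a "data: " prefix
// and a blank-line event terminator.
func frameSSE(payload any) (string, error) {
	b, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}
	return "data: " + string(b) + "\n\n", nil
}

// parseSSE reverses frameSSE for a single event, mirroring the
// TrimPrefix/TrimRight steps used in the test.
func parseSSE(event string) (map[string]any, error) {
	data := strings.TrimPrefix(event, "data: ")
	data = strings.TrimRight(data, "\n")
	var m map[string]any
	if err := json.Unmarshal([]byte(data), &m); err != nil {
		return nil, err
	}
	return m, nil
}

func main() {
	ev, err := frameSSE(map[string]any{"choices": []any{map[string]any{"index": 0}}})
	if err != nil {
		panic(err)
	}
	m, err := parseSSE(ev)
	if err != nil {
		panic(err)
	}
	fmt.Println(strings.HasPrefix(ev, "data: "), m["choices"] != nil)
}
```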


@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))
 	assert.Equal(t, "rerank-1", result["model"])
-	results := result["results"].([]any)
+	results, okr := result["results"].([]any)
+	require.True(t, okr)
 	assert.Len(t, results, 1)
 }
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
 	reasoning := 20
 	sr := canonical.StopReasonEndTurn
 	resp := &canonical.CanonicalResponse{
 		ID:         "r1",
 		Model:      "gpt-4",
 		Content:    []canonical.ContentBlock{canonical.NewTextBlock("ok")},
 		StopReason: &sr,
 		Usage: canonical.CanonicalUsage{
 			InputTokens: 100,
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
 	var result map[string]any
 	require.NoError(t, json.Unmarshal(body, &result))
-	usage := result["usage"].(map[string]any)
+	usage, oku := result["usage"].(map[string]any)
+	require.True(t, oku)
 	assert.Equal(t, float64(100), usage["prompt_tokens"])
 	ptd, ok := usage["prompt_tokens_details"].(map[string]any)
 	require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
 			var result map[string]any
 			require.NoError(t, json.Unmarshal(body, &result))
-			choices := result["choices"].([]any)
-			choice := choices[0].(map[string]any)
+			choices, okch := result["choices"].([]any)
+			require.True(t, okch)
+			choice, okc := choices[0].(map[string]any)
+			require.True(t, okc)
 			assert.Equal(t, tt.want, choice["finish_reason"])
 		})
 	}
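The hunks above replace bare type assertions like `result["usage"].(map[string]any)` with the comma-ok form, which is what errcheck's `check-type-assertions` flags. A minimal stdlib-only sketch of the difference (names like `secondOf` are illustrative, not from this repo):

```go
package main

import "fmt"

// secondOf uses the comma-ok assertion pattern the diff adopts:
// a bare v := x.(T) panics on a type mismatch, while v, ok := x.(T)
// lets the caller handle the failure explicitly.
func secondOf(payload map[string]any) (string, bool) {
	items, ok := payload["results"].([]any)
	if !ok || len(items) < 2 {
		return "", false
	}
	s, ok := items[1].(string)
	return s, ok
}

func main() {
	payload := map[string]any{"results": []any{"a", "b"}}
	if v, ok := secondOf(payload); ok {
		fmt.Println(v) // b
	}
	// A mismatched shape no longer panics; it just reports !ok.
	_, ok := secondOf(map[string]any{"results": "not-a-slice"})
	fmt.Println(ok) // false
}
```

In tests the same idea pairs naturally with `require.True(t, ok)`, as the diff does, so a malformed JSON response fails the assertion instead of panicking the test binary.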

View File

@@ -4,42 +4,42 @@ import "encoding/json"
 // ChatCompletionRequest OpenAI Chat Completion 请求
 type ChatCompletionRequest struct {
 	Model               string          `json:"model"`
 	Messages            []Message       `json:"messages"`
 	Tools               []Tool          `json:"tools,omitempty"`
 	ToolChoice          any             `json:"tool_choice,omitempty"`
 	MaxTokens           *int            `json:"max_tokens,omitempty"`
 	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
 	Temperature         *float64        `json:"temperature,omitempty"`
 	TopP                *float64        `json:"top_p,omitempty"`
 	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
 	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
 	Stop                any             `json:"stop,omitempty"`
 	Stream              bool            `json:"stream,omitempty"`
 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
 	User                string          `json:"user,omitempty"`
 	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
 	ParallelToolCalls   *bool           `json:"parallel_tool_calls,omitempty"`
 	ReasoningEffort     string          `json:"reasoning_effort,omitempty"`
 	N                   *int            `json:"n,omitempty"`
 	Seed                *int            `json:"seed,omitempty"`
 	Logprobs            *bool           `json:"logprobs,omitempty"`
 	TopLogprobs         *int            `json:"top_logprobs,omitempty"`
 	// 已废弃字段
 	Functions    []FunctionDef `json:"functions,omitempty"`
 	FunctionCall any           `json:"function_call,omitempty"`
 }

 // Message OpenAI 消息
 type Message struct {
 	Role             string     `json:"role"`
 	Content          any        `json:"content"`
 	Name             string     `json:"name,omitempty"`
 	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
 	ToolCallID       string     `json:"tool_call_id,omitempty"`
 	Refusal          string     `json:"refusal,omitempty"`
 	ReasoningContent string     `json:"reasoning_content,omitempty"`
 	// 已废弃
 	FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
@@ -88,8 +88,8 @@ type FunctionDef struct {
 // ResponseFormat OpenAI 响应格式
 type ResponseFormat struct {
 	Type       string         `json:"type"`
 	JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
 }

 // JSONSchemaDef JSON Schema 定义
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
 // Choice OpenAI 选择项
 type Choice struct {
 	Index        int      `json:"index"`
 	Message      *Message `json:"message,omitempty"`
 	Delta        *Message `json:"delta,omitempty"`
 	FinishReason *string  `json:"finish_reason"`
@@ -127,10 +127,10 @@ type Choice struct {
 // Usage OpenAI 用量
 type Usage struct {
 	PromptTokens            int                      `json:"prompt_tokens"`
 	CompletionTokens        int                      `json:"completion_tokens"`
 	TotalTokens             int                      `json:"total_tokens"`
 	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details,omitempty"`
 	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
 }

View File

@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
 		return gorm.Open(mysql.Open(dsn), gormConfig)
 	default:
 		dbDir := filepath.Dir(cfg.Path)
-		if err := os.MkdirAll(dbDir, 0755); err != nil {
+		if err := os.MkdirAll(dbDir, 0o755); err != nil {
 			return nil, fmt.Errorf("创建数据库目录失败: %w", err)
 		}
 		if zapLogger != nil {
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
 			zap.String("dir", migrationsSubDir))
 	}
-	goose.SetDialect(gooseDialect)
+	if err := goose.SetDialect(gooseDialect); err != nil {
+		return err
+	}
 	if err := goose.Up(sqlDB, migrationsDir); err != nil {
 		return err
 	}

View File

@@ -4,11 +4,11 @@ import (
 	"path/filepath"
 	"testing"

-	"nex/backend/internal/config"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
+	"nex/backend/internal/config"
 )

 func TestInit_SQLite(t *testing.T) {

View File

@@ -13,4 +13,3 @@ type Provider struct {
 	CreatedAt time.Time `json:"created_at"`
 	UpdatedAt time.Time `json:"updated_at"`
 }
-

View File

@@ -6,13 +6,13 @@ import (
 	"net/http/httptest"
 	"testing"

-	"nex/backend/internal/domain"
-	"nex/backend/tests/mocks"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
+	"nex/backend/internal/domain"
+	"nex/backend/tests/mocks"
 )

 func TestProviderHandler_CreateProvider_Success(t *testing.T) {
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
 	h := NewProviderHandler(mockSvc)

 	body, _ := json.Marshal(map[string]string{
 		"id":       "p1",
 		"name":     "Test",
 		"api_key":  "sk-test",
 		"base_url": "https://api.test.com",
 	})
 	w := httptest.NewRecorder()

View File

@@ -9,23 +9,22 @@ import (
 	"strings"
 	"testing"

-	"nex/backend/internal/domain"
-	"nex/backend/tests/mocks"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
 	"gorm.io/gorm"
+	"nex/backend/internal/domain"
 	appErrors "nex/backend/pkg/errors"
+	"nex/backend/tests/mocks"
 )

 func init() {
 	gin.SetMode(gin.TestMode)
 }

 func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	defer ctrl.Finish()

View File

@@ -20,7 +20,6 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
 		if id, ok := requestID.(string); ok {
 			requestIDStr = id
 		}
-
 		logger.Info("请求开始",
 			pkglogger.Method(c.Request.Method),
 			pkglogger.Path(path),

View File

@@ -4,13 +4,13 @@ import (
 	"errors"
 	"net/http"

-	"nex/backend/internal/domain"
-	"nex/backend/internal/service"
 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
 	appErrors "nex/backend/pkg/errors"
+	"nex/backend/internal/domain"
+	"nex/backend/internal/service"
 )

 // ModelHandler 模型管理处理器
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
 	err := h.modelService.Create(model)
 	if err != nil {
-		if err == appErrors.ErrProviderNotFound {
+		if errors.Is(err, appErrors.ErrProviderNotFound) {
 			c.JSON(http.StatusBadRequest, gin.H{
 				"error": "供应商不存在",
 			})
 			return
 		}
-		if err == appErrors.ErrDuplicateModel {
+		if errors.Is(err, appErrors.ErrDuplicateModel) {
 			c.JSON(http.StatusConflict, gin.H{
 				"error": "同一供应商下模型名称已存在",
 				"code":  appErrors.ErrDuplicateModel.Code,
 			})
 			return
 		}
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
 	model, err := h.modelService.Get(id)
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{
 				"error": "模型未找到",
 			})
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
 	err := h.modelService.Delete(id)
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{
 				"error": "模型未找到",
 			})

View File

@@ -4,13 +4,13 @@ import (
 	"errors"
 	"net/http"

-	"nex/backend/internal/domain"
-	"nex/backend/internal/service"
 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
 	appErrors "nex/backend/pkg/errors"
+	"nex/backend/internal/domain"
+	"nex/backend/internal/service"
 )

 // ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
 	err := h.providerService.Create(provider)
 	if err != nil {
-		if err == appErrors.ErrInvalidProviderID {
+		if errors.Is(err, appErrors.ErrInvalidProviderID) {
 			c.JSON(http.StatusBadRequest, gin.H{
 				"error": appErrors.ErrInvalidProviderID.Message,
 				"code":  appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
 	provider, err := h.providerService.Get(id)
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{
 				"error": "供应商未找到",
 			})
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
 	err := h.providerService.Update(id, req)
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{
 				"error": "供应商未找到",
 			})
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
 	err := h.providerService.Delete(id)
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{
 				"error": "供应商未找到",
 			})

View File

@@ -3,30 +3,32 @@ package handler
 import (
 	"bufio"
 	"encoding/json"
+	"errors"
 	"io"
 	"net/http"
 	"strings"

-	"github.com/gin-gonic/gin"
-	"go.uber.org/zap"
 	"nex/backend/internal/conversion"
 	"nex/backend/internal/conversion/canonical"
 	"nex/backend/internal/domain"
 	"nex/backend/internal/provider"
 	"nex/backend/internal/service"
 	"nex/backend/pkg/modelid"
+
+	"github.com/gin-gonic/gin"
+	"go.uber.org/zap"
 	pkglogger "nex/backend/pkg/logger"
 )

 // ProxyHandler 统一代理处理器
 type ProxyHandler struct {
 	engine          *conversion.ConversionEngine
 	client          provider.ProviderClient
 	routingService  service.RoutingService
 	providerService service.ProviderService
 	statsService    service.StatsService
 	logger          *zap.Logger
 }

 // NewProxyHandler 创建统一代理处理器
@@ -138,7 +140,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
 	targetProvider := conversion.NewTargetProvider(
 		routeResult.Provider.BaseURL,
 		routeResult.Provider.APIKey,
 		routeResult.Model.ModelName, // 上游模型名,用于请求改写
 	)

 	// 判断是否流式
@@ -159,7 +161,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
 	// 转换请求
 	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
 	if err != nil {
-		h.logger.Error("转换请求失败", zap.String("error", err.Error()))
+		h.logger.Error("转换请求失败", zap.Error(err))
 		h.writeConversionError(c, err, clientProtocol)
 		return
 	}
@@ -167,7 +169,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
 	// 发送请求
 	resp, err := h.client.Send(c.Request.Context(), *outSpec)
 	if err != nil {
-		h.logger.Error("发送请求失败", zap.String("error", err.Error()))
+		h.logger.Error("发送请求失败", zap.Error(err))
 		h.writeConversionError(c, err, clientProtocol)
 		return
 	}
@@ -175,7 +177,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
 	// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
 	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
 	if err != nil {
-		h.logger.Error("转换响应失败", zap.String("error", err.Error()))
+		h.logger.Error("转换响应失败", zap.Error(err))
 		h.writeConversionError(c, err, clientProtocol)
 		return
 	}
@@ -191,7 +193,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
 	c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)

 	go func() {
-		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
+		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
 	}()
 }
@@ -226,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
 	for event := range eventChan {
 		if event.Error != nil {
-			h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
+			h.logger.Error("流读取错误", zap.Error(event.Error))
 			break
 		}
 		if event.Done {
 			// flush 转换器
 			chunks := streamConverter.Flush()
-			for _, chunk := range chunks {
-				writer.Write(chunk)
-				writer.Flush()
+			if err := h.writeStreamChunks(writer, chunks); err != nil {
+				h.logger.Warn("流式响应写回失败", zap.Error(err))
 			}
 			break
 		}
 		chunks := streamConverter.ProcessChunk(event.Data)
-		for _, chunk := range chunks {
-			writer.Write(chunk)
-			writer.Flush()
+		if err := h.writeStreamChunks(writer, chunks); err != nil {
+			h.logger.Warn("流式响应写回失败", zap.Error(err))
+			break
 		}
 	}

 	go func() {
-		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
+		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
 	}()
 }

+func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
+	for _, chunk := range chunks {
+		if _, err := writer.Write(chunk); err != nil {
+			return err
+		}
+		if err := writer.Flush(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // isStreamRequest 判断是否流式请求
 func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
-	ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
+	ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
+	if err != nil {
+		return false
+	}
 	if ifaceType != conversion.InterfaceTypeChat {
 		return false
 	}
@@ -272,7 +290,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
 	// 从数据库查询所有启用的模型
 	models, err := h.providerService.ListEnabledModels()
 	if err != nil {
-		h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
+		h.logger.Error("查询启用模型失败", zap.Error(err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
 		return
 	}
@@ -294,7 +312,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
 	// 使用 adapter 编码返回
 	body, err := adapter.EncodeModelsResponse(modelList)
 	if err != nil {
-		h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
+		h.logger.Error("编码 Models 响应失败", zap.Error(err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
 		return
 	}
@@ -342,8 +360,13 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
 // writeConversionError 写入转换错误
 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
-	if convErr, ok := err.(*conversion.ConversionError); ok {
-		body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
+	var convErr *conversion.ConversionError
+	if errors.As(err, &convErr) {
+		body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
+		if encodeErr != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
+			return
+		}
 		c.Data(statusCode, "application/json", body)
 		return
 	}

View File

@@ -8,27 +8,26 @@ import (
 	"net/http/httptest"
 	"testing"

-	"nex/backend/internal/conversion"
-	"nex/backend/internal/conversion/anthropic"
-	"nex/backend/internal/conversion/openai"
-	"nex/backend/internal/domain"
-	"nex/backend/internal/provider"
-	"nex/backend/tests/mocks"
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
 	"go.uber.org/zap"
+	"nex/backend/internal/conversion"
+	"nex/backend/internal/conversion/anthropic"
+	"nex/backend/internal/conversion/openai"
+	"nex/backend/internal/domain"
+	"nex/backend/internal/provider"
 	appErrors "nex/backend/pkg/errors"
+	"nex/backend/tests/mocks"
 )

 func init() {
 	gin.SetMode(gin.TestMode)
 }

 func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
 	t.Helper()
 	registry := conversion.NewMemoryRegistry()
@@ -844,7 +843,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
 	require.True(t, ok)
 	assert.Len(t, data, 2)
-	first := data[0].(map[string]interface{})
+	first, ok2 := data[0].(map[string]interface{})
+	require.True(t, ok2)
 	assert.Equal(t, "openai/gpt-4", first["id"])
 }
@@ -918,7 +918,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
 	client := mocks.NewMockProviderClient(ctrl)
 	client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
 		var req map[string]interface{}
-		json.Unmarshal(spec.Body, &req)
+		require.NoError(t, json.Unmarshal(spec.Body, &req))
 		assert.Equal(t, "gpt-4", req["model"])
 		return &conversion.HTTPResponseSpec{

View File

@@ -5,9 +5,9 @@ import (
 	"net/http"
 	"time"

-	"github.com/gin-gonic/gin"
 	"nex/backend/internal/service"
+	"github.com/gin-gonic/gin"
 )

 // StatsHandler 统计处理器

View File

@@ -51,6 +51,7 @@ type Client struct {
 }

 // ProviderClient 供应商客户端接口
+//
 //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
 type ProviderClient interface {
 	Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
@@ -141,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
 	if resp.StatusCode != http.StatusOK {
 		defer resp.Body.Close()
 		cancel()
-		errBody, _ := io.ReadAll(resp.Body)
+		errBody, readErr := io.ReadAll(resp.Body)
+		if readErr != nil {
+			return nil, fmt.Errorf("供应商返回错误: HTTP %d,读取错误响应失败: %w", resp.StatusCode, readErr)
+		}
 		if len(errBody) > 0 {
 			return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
 		}
@@ -184,7 +188,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
 		if err != nil {
 			if err != io.EOF {
 				if isNetworkError(err) {
-					c.logger.Error("流网络错误", zap.String("error", err.Error()))
+					c.logger.Error("流网络错误", zap.Error(err))
 					eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
 				} else {
 					c.logger.Error("流读取错误", zap.Error(err))

View File

@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
 		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(http.StatusOK)
-		w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
+		_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
+		require.NoError(t, err)
 	}))
 	defer server.Close()
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
 func TestClient_Send_ErrorResponse(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusUnauthorized)
-		w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
+		_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
+		require.NoError(t, err)
 	}))
 	defer server.Close()
@@ -140,12 +142,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
 		w.WriteHeader(http.StatusOK)
 		flusher, ok := w.(http.Flusher)
 		require.True(t, ok)
-		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
+		_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
-		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
+		_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(50 * time.Millisecond)
-		w.Write([]byte("data: [DONE]\n\n"))
+		_, err = w.Write([]byte("data: [DONE]\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(50 * time.Millisecond)
 	}))
@@ -165,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
 	var dataEvents [][]byte
 	var doneEvents int
 	for event := range eventChan {
-		if event.Done {
+		switch {
+		case event.Done:
 			doneEvents++
-		} else if event.Error != nil {
+		case event.Error != nil:
 			t.Fatalf("unexpected error: %v", event.Error)
-		} else {
+		default:
 			dataEvents = append(dataEvents, event.Data)
 		}
 	}
@@ -215,7 +221,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		assert.Equal(t, "GET", r.Method)
 		w.WriteHeader(http.StatusOK)
-		w.Write([]byte(`{"result":"ok"}`))
+		_, err := w.Write([]byte(`{"result":"ok"}`))
+		require.NoError(t, err)
 	}))
 	defer server.Close()
@@ -238,10 +245,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
 		w.WriteHeader(http.StatusOK)
 		flusher, ok := w.(http.Flusher)
 		require.True(t, ok)
-		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
+		_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(100 * time.Millisecond)
-		w.Write([]byte("data: [DONE]\n\n"))
+		_, err = w.Write([]byte("data: [DONE]\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(100 * time.Millisecond)
 	}))
@@ -261,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
 	var dataCount int
 	var doneCount int
 	for event := range eventChan {
-		if event.Done {
+		switch {
+		case event.Done:
 			doneCount++
-		} else if event.Error != nil {
+		case event.Error != nil:
 			t.Fatalf("unexpected error: %v", event.Error)
-		} else {
+		default:
 			dataCount++
 		}
 	}
@@ -279,10 +289,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
 		w.WriteHeader(http.StatusOK)
 		flusher, ok := w.(http.Flusher)
 		require.True(t, ok)
-		w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
+		_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(50 * time.Millisecond)
-		w.Write([]byte("data: [DONE]\n\n"))
+		_, err = w.Write([]byte("data: [DONE]\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(50 * time.Millisecond)
 	}))
@@ -364,13 +376,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
 		w.WriteHeader(http.StatusOK)
 		flusher, ok := w.(http.Flusher)
 		require.True(t, ok)
-		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
+		_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(50 * time.Millisecond)
 		if hijacker, ok := w.(http.Hijacker); ok {
 			conn, _, _ := hijacker.Hijack()
 			if conn != nil {
-				conn.Close()
+				require.NoError(t, conn.Close())
 			}
 		}
 	}))

View File

@@ -3,10 +3,11 @@ package repository
 import (
 	"time"

-	"gorm.io/gorm"
 	"nex/backend/internal/config"
 	"nex/backend/internal/domain"
+
+	"gorm.io/gorm"
 	appErrors "nex/backend/pkg/errors"
 )

View File

@@ -3,13 +3,13 @@ package repository
 import (
 	"testing"

-	"nex/backend/internal/domain"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
 	testHelpers "nex/backend/tests"
+	"nex/backend/internal/domain"
 )

 func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -3,11 +3,11 @@ package repository
 import (
 	"time"

-	"gorm.io/gorm"
-	"gorm.io/gorm/clause"
 	"nex/backend/internal/config"
 	"nex/backend/internal/domain"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
 )

 type statsRepository struct {
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
 }

 func (r *statsRepository) Record(providerID, modelName string) error {
-	today := time.Now().Format("2006-01-02")
-	todayTime, _ := time.Parse("2006-01-02", today)
+	now := time.Now()
+	todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
 	stats := config.UsageStats{
 		ProviderID: providerID,

View File

@@ -1,11 +1,15 @@
 package service

 import (
+	"errors"

 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"

 	"github.com/google/uuid"
+	"gorm.io/gorm"

 	appErrors "nex/backend/pkg/errors"
 )

 type modelService struct {
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
 func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
 	existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
 	if err != nil {
-		return nil // not found, not a duplicate
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil // not found, not a duplicate
+		}
+		return err
 	}
 	if excludeID != "" && existing.ID == excludeID {
 		return nil // exclude self

View File

@@ -3,10 +3,10 @@ package service
 import (
 	"strings"

 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"
 	appErrors "nex/backend/pkg/errors"
 	"nex/backend/pkg/modelid"
 )

View File

@@ -4,10 +4,11 @@ import (
 	"strings"
 	"sync"

 	"go.uber.org/zap"

 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"
 	pkglogger "nex/backend/pkg/logger"
 )
@@ -34,7 +35,9 @@ func NewRoutingCache(
 func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
 	if v, ok := c.providers.Load(id); ok {
-		return v.(*domain.Provider), nil
+		if provider, ok := v.(*domain.Provider); ok {
+			return provider, nil
+		}
 	}
 	provider, err := c.providerRepo.GetByID(id)
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
 	}

 	if v, ok := c.providers.Load(id); ok {
-		return v.(*domain.Provider), nil
+		if provider, ok := v.(*domain.Provider); ok {
+			return provider, nil
+		}
 	}
 	c.providers.Store(id, provider)
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
 	key := providerID + "/" + modelName
 	if v, ok := c.models.Load(key); ok {
-		return v.(*domain.Model), nil
+		if model, ok := v.(*domain.Model); ok {
+			return model, nil
+		}
 	}
 	model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
 	}

 	if v, ok := c.models.Load(key); ok {
-		return v.(*domain.Model), nil
+		if model, ok := v.(*domain.Model); ok {
+			return model, nil
+		}
 	}
 	c.models.Store(key, model)
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
 	prefix := providerID + "/"
 	count := 0
 	c.models.Range(func(key, value interface{}) bool {
-		if strings.HasPrefix(key.(string), prefix) {
+		keyStr, ok := key.(string)
+		if !ok {
+			return true
+		}
+		if strings.HasPrefix(keyStr, prefix) {
 			c.models.Delete(key)
 			count++
 		}

View File

@@ -5,11 +5,11 @@ import (
 	"sync"
 	"testing"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"

 	"nex/backend/internal/domain"
 )
 type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
 	var openaiCount, anthropicCount int
 	cache.models.Range(func(key, value interface{}) bool {
-		if key.(string) == "anthropic/claude" {
+		keyStr, ok := key.(string)
+		if ok && keyStr == "anthropic/claude" {
 			anthropicCount++
 		}
 		return true

View File

@@ -1,9 +1,8 @@
 package service

 import (
 	"nex/backend/internal/domain"
 	appErrors "nex/backend/pkg/errors"
 )

 type routingService struct {

View File

@@ -3,12 +3,12 @@ package service
 import (
 	"testing"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"

 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"
 )
 func TestProviderService_Update(t *testing.T) {
@@ -133,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
 	totalCount := 0
 	for _, r := range result {
-		totalCount += r["request_count"].(int)
+		count, ok := r["request_count"].(int)
+		require.True(t, ok)
+		totalCount += count
 	}
 	assert.Equal(t, 15, totalCount)
 }

View File

@@ -5,6 +5,9 @@ import (
 	"testing"
 	"time"

+	"nex/backend/internal/domain"
+	"nex/backend/internal/repository"
+
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -13,8 +16,6 @@ import (
 	testHelpers "nex/backend/tests"

-	"nex/backend/internal/domain"
-	"nex/backend/internal/repository"
 	appErrors "nex/backend/pkg/errors"
 )
@@ -134,7 +135,7 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewModelService(modelRepo, providerRepo, cache)

 	model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
@@ -148,7 +149,7 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewProviderService(repo, modelRepo, cache)

 	provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -160,7 +161,7 @@ func TestProviderService_Create_ValidID(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewProviderService(repo, modelRepo, cache)

 	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -176,7 +177,7 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewModelService(modelRepo, providerRepo, cache)

 	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -202,7 +203,7 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewModelService(modelRepo, providerRepo, cache)

 	err := svc.Update("nonexistent-id", map[string]interface{}{
@@ -215,7 +216,7 @@ func TestModelService_Update_Success(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewModelService(modelRepo, providerRepo, cache)

 	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -241,7 +242,7 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewProviderService(repo, modelRepo, cache)

 	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -259,7 +260,7 @@ func TestProviderService_Update_Success(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewProviderService(repo, modelRepo, cache)

 	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			db := setupServiceTestDB(t)
 			statsRepo := repository.NewStatsRepository(db)
-			buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
+			buffer := NewStatsBuffer(statsRepo, zap.NewNop())
+			svc := NewStatsService(statsRepo, buffer)

 			result := svc.Aggregate(tt.stats, "model")
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			db := setupServiceTestDB(t)
 			statsRepo := repository.NewStatsRepository(db)
-			buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
+			buffer := NewStatsBuffer(statsRepo, zap.NewNop())
+			svc := NewStatsService(statsRepo, buffer)

 			result := svc.Aggregate(tt.stats, "date")
@@ -448,7 +451,7 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewProviderService(repo, modelRepo, cache)

 	provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
@@ -474,7 +477,7 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	cache := setupRoutingCache(t, db)
 	svc := NewModelService(modelRepo, providerRepo, cache)

 	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

View File

@@ -6,9 +6,10 @@ import (
 	"sync/atomic"
 	"time"

 	"go.uber.org/zap"

 	"nex/backend/internal/repository"
 	pkglogger "nex/backend/pkg/logger"
 )
@@ -67,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
 	var counter *int64
 	if v, ok := b.counters.Load(key); ok {
-		counter = v.(*int64)
+		if existing, ok := v.(*int64); ok {
+			counter = existing
+		} else {
+			return
+		}
 	} else {
 		val := int64(0)
 		counter = &val
 		actual, loaded := b.counters.LoadOrStore(key, counter)
 		if loaded {
-			counter = actual.(*int64)
+			existing, ok := actual.(*int64)
+			if !ok {
+				return
+			}
+			counter = existing
 		}
 	}
@@ -117,13 +126,20 @@ func (b *StatsBuffer) flush() {
 	var entries []statEntry
 	b.counters.Range(func(key, value interface{}) bool {
-		keyStr := key.(string)
+		keyStr, ok := key.(string)
+		if !ok {
+			return true
+		}
 		parts := strings.Split(keyStr, "/")
 		if len(parts) != 3 {
 			return true
 		}
-		counter := value.(*int64)
+		counter, ok := value.(*int64)
+		if !ok {
+			return true
+		}
 		count := atomic.SwapInt64(counter, 0)
 		if count > 0 {
@@ -143,8 +159,17 @@ func (b *StatsBuffer) flush() {
 	success := 0
 	for _, entry := range entries {
-		date, _ := time.Parse("2006-01-02", entry.date)
-		err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
+		date, err := time.Parse("2006-01-02", entry.date)
+		if err != nil {
+			b.logger.Error("解析统计日期失败",
+				zap.String("provider_id", entry.providerID),
+				zap.String("model_name", entry.modelName),
+				zap.String("date", entry.date),
+				zap.Error(err))
+			continue
+		}
+		err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
 		if err != nil {
 			b.logger.Error("批量更新统计失败",
 				zap.String("provider_id", entry.providerID),
@@ -154,8 +179,10 @@ func (b *StatsBuffer) flush() {
 			key := entry.providerID + "/" + entry.modelName + "/" + entry.date
 			if v, ok := b.counters.Load(key); ok {
-				counter := v.(*int64)
-				atomic.AddInt64(counter, entry.count)
+				counter, ok := v.(*int64)
+				if ok {
+					atomic.AddInt64(counter, entry.count)
+				}
 			}
 		} else {
 			success++

View File

@@ -7,10 +7,10 @@ import (
 	"testing"
 	"time"

 	"github.com/stretchr/testify/assert"
 	"go.uber.org/zap"

 	"nex/backend/internal/domain"
 )
 type mockStatsRepo struct {
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
 	var count int64
 	buffer.counters.Range(func(key, value interface{}) bool {
-		counter := value.(*int64)
-		count += atomic.LoadInt64(counter)
+		counter, ok := value.(*int64)
+		if ok {
+			count += atomic.LoadInt64(counter)
+		}
 		return true
 	})
 	assert.Equal(t, int64(3), count)
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
 	var count int64
 	buffer.counters.Range(func(key, value interface{}) bool {
-		counter := value.(*int64)
-		count = atomic.LoadInt64(counter)
+		counter, ok := value.(*int64)
+		if ok {
+			count = atomic.LoadInt64(counter)
+		}
 		return true
 	})
 	assert.Equal(t, int64(100), count)
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
 	var beforeCount int64
 	buffer.counters.Range(func(key, value interface{}) bool {
-		counter := value.(*int64)
-		beforeCount = atomic.LoadInt64(counter)
+		counter, ok := value.(*int64)
+		if ok {
+			beforeCount = atomic.LoadInt64(counter)
+		}
 		return true
 	})
 	assert.Equal(t, int64(2), beforeCount)
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
 	var afterCount int64
 	buffer.counters.Range(func(key, value interface{}) bool {
-		counter := value.(*int64)
-		afterCount = atomic.LoadInt64(counter)
+		counter, ok := value.(*int64)
+		if ok {
+			afterCount = atomic.LoadInt64(counter)
+		}
 		return true
 	})
 	assert.Equal(t, int64(0), afterCount)
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
 	var count int64
 	buffer.counters.Range(func(key, value interface{}) bool {
-		counter := value.(*int64)
-		count = atomic.LoadInt64(counter)
+		counter, ok := value.(*int64)
+		if ok {
+			count = atomic.LoadInt64(counter)
+		}
 		return true
 	})
 	assert.Equal(t, int64(2), count)

View File

@@ -1,6 +1,7 @@
 package errors

 import (
+	stderrors "errors"
 	"fmt"
 	"net/http"
 )
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
 	if err == nil {
 		return nil, false
 	}
-	var appErr *AppError
-	if ok := is(err, &appErr); ok {
-		return appErr, true
-	}
-	return nil, false
-}
-
-func is(err error, target interface{}) bool {
-	// simple type assertion
-	if e, ok := err.(*AppError); ok {
-		// direct assignment
-		switch t := target.(type) {
-		case **AppError:
-			*t = e
-			return true
-		}
-	}
-	return false
+	var appErr *AppError
+	if !stderrors.As(err, &appErr) {
+		return nil, false
+	}
+	return appErr, true
 }

View File

@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
 func TestAsAppError(t *testing.T) {
 	t.Run("nil输入", func(t *testing.T) {
-		_, ok := AsAppError(nil)
+		appErr, ok := AsAppError(nil)
+		assert.Nil(t, appErr)
 		assert.False(t, ok)
 	})
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
 	})

 	t.Run("非AppError类型", func(t *testing.T) {
-		_, ok := AsAppError(errors.New("普通错误"))
+		appErr, ok := AsAppError(errors.New("普通错误"))
+		assert.Nil(t, appErr)
 		assert.False(t, ok)
 	})
 }

View File

@@ -19,7 +19,7 @@ func TestNew_StdoutOnly(t *testing.T) {
 func TestNew_WithFileOutput(t *testing.T) {
 	dir := filepath.Join(os.TempDir(), "nex-logger-test")
-	os.MkdirAll(dir, 0755)
+	require.NoError(t, os.MkdirAll(dir, 0o755))
 	defer os.RemoveAll(dir)

 	logger, err := New(Config{
@@ -81,7 +81,7 @@ func TestParseLevel(t *testing.T) {
 		{"info", true},
 		{"warn", true},
 		{"error", true},
 		{"", true},        // defaults to info
 		{"invalid", true}, // defaults to info
 	}

 	for _, tt := range tests {

View File

@@ -22,9 +22,9 @@ func newRotateWriter(cfg Config) *lumberjack.Logger {
 	return &lumberjack.Logger{
 		Filename:   logFilePath(cfg.Path),
 		MaxSize:    maxSize, // MB
 		MaxBackups: maxBackups,
 		MaxAge:     maxAge, // days
 		Compress:   cfg.Compress,
 	}
 }

View File

@@ -6,10 +6,10 @@ import (
 	"testing"
 	"time"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"

 	"nex/backend/internal/config"
 )

 func TestLoadConfig_DefaultValues(t *testing.T) {
@@ -72,7 +72,7 @@ log:
   max_age: 7
   compress: false
 `
-	err := os.WriteFile(configPath, []byte(yamlContent), 0644)
+	err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
 	require.NoError(t, err)

 	cfg, err := config.LoadConfigFromPath(configPath)
@@ -103,7 +103,7 @@ server:
 log:
   level: warn
 `
-	err := os.WriteFile(configPath, []byte(yamlContent), 0644)
+	err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
 	require.NoError(t, err)

 	t.Setenv("NEX_SERVER_PORT", "9000")
@@ -147,7 +147,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
 	}
 	defer func() {
 		if originalConfig != nil {
-			_ = os.WriteFile(configPath, originalConfig, 0644)
+			require.NoError(t, os.WriteFile(configPath, originalConfig, 0o600))
 		}
 	}()

View File

@@ -11,20 +11,21 @@ import (
 	"testing"
 	"time"

 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 	"gorm.io/gorm"

 	"nex/backend/internal/conversion"
 	"nex/backend/internal/conversion/anthropic"
 	openaiConv "nex/backend/internal/conversion/openai"
 	"nex/backend/internal/handler"
 	"nex/backend/internal/handler/middleware"
 	"nex/backend/internal/provider"
 	"nex/backend/internal/repository"
 	"nex/backend/internal/service"
 )
 func init() {
@@ -39,7 +40,8 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		// return success by default; individual test cases override this
 		w.WriteHeader(http.StatusOK)
-		w.Write([]byte(`{"error":"not mocked"}`))
+		_, err := w.Write([]byte(`{"error":"not mocked"}`))
+		require.NoError(t, err)
 	}))

 	db := setupTestDB(t)
@@ -124,7 +126,6 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
 	require.Equal(t, 201, w.Code)

 	modelBody, _ := json.Marshal(map[string]string{
 		"provider_id": providerID,
 		"model_name":  modelName,
 	})
@@ -143,9 +144,10 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
 	// configure the upstream to return an Anthropic-format response
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		// verify the request was converted to Anthropic format
-		body, _ := io.ReadAll(r.Body)
+		body, err := io.ReadAll(r.Body)
+		require.NoError(t, err)
 		var req map[string]any
-		json.Unmarshal(body, &req)
+		require.NoError(t, json.Unmarshal(body, &req))

 		assert.Equal(t, "/v1/messages", r.URL.Path)
 		assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
@@ -166,7 +168,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
 		},
 	}
 	w.Header().Set("Content-Type", "application/json")
-	json.NewEncoder(w).Encode(resp)
+	require.NoError(t, json.NewEncoder(w).Encode(resp))
 })

 createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
@@ -189,13 +191,16 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
 	assert.Equal(t, 200, w.Code)

 	var resp map[string]any
-	json.Unmarshal(w.Body.Bytes(), &resp)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
 	assert.Equal(t, "chat.completion", resp["object"])

-	choices := resp["choices"].([]any)
+	choices, ok := resp["choices"].([]any)
+	require.True(t, ok)
 	require.Len(t, choices, 1)
-	choice := choices[0].(map[string]any)
-	msg := choice["message"].(map[string]any)
+	choice, ok := choices[0].(map[string]any)
+	require.True(t, ok)
+	msg, ok := choice["message"].(map[string]any)
+	require.True(t, ok)
 	assert.Contains(t, msg["content"], "Hello from Anthropic!")
 }
@@ -203,9 +208,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
 	r, _, upstream := setupConversionTest(t)

 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		body, _ := io.ReadAll(r.Body)
+		body, err := io.ReadAll(r.Body)
+		require.NoError(t, err)
 		var req map[string]any
-		json.Unmarshal(body, &req)
+		require.NoError(t, json.Unmarshal(body, &req))
 		assert.Equal(t, "/chat/completions", r.URL.Path)
 		assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
@@ -229,7 +235,7 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
 		},
 	}
 	w.Header().Set("Content-Type", "application/json")
-	json.NewEncoder(w).Encode(resp)
+	require.NoError(t, json.NewEncoder(w).Encode(resp))
 })

 createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
@@ -252,12 +258,14 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
 	assert.Equal(t, 200, w.Code)

 	var resp map[string]any
-	json.Unmarshal(w.Body.Bytes(), &resp)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
 	assert.Equal(t, "message", resp["type"])

-	content := resp["content"].([]any)
+	content, ok := resp["content"].([]any)
+	require.True(t, ok)
 	require.Len(t, content, 1)
-	block := content[0].(map[string]any)
+	block, ok2 := content[0].(map[string]any)
+	require.True(t, ok2)
 	assert.Contains(t, block["text"], "Hello from OpenAI!")
 }
@@ -269,21 +277,23 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		assert.Equal(t, "/chat/completions", r.URL.Path)
-		body, _ := io.ReadAll(r.Body)
+		body, err := io.ReadAll(r.Body)
+		require.NoError(t, err)
 		var req map[string]any
-		json.Unmarshal(body, &req)
+		require.NoError(t, json.Unmarshal(body, &req))
 		// Smart Passthrough: the unified ID in the request body is rewritten to the upstream model name
 		assert.Equal(t, "gpt-4", req["model"])

 		w.Header().Set("Content-Type", "application/json")
 		// the upstream responds with its own model name
-		w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
+		_, err = w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
+		require.NoError(t, err)
 	})

 	createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)

 	reqBody := map[string]any{
 		"model":    "openai_p/gpt-4", // client sends the unified ID
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 	}
 	body, _ := json.Marshal(reqBody)
@@ -304,21 +314,23 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/messages", r.URL.Path)
-body, _ := io.ReadAll(r.Body)
+body, err := io.ReadAll(r.Body)
+require.NoError(t, err)
var req map[string]any
-json.Unmarshal(body, &req)
+require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: the unified ID in the request body should be rewritten to the upstream model name
assert.Equal(t, "claude-3-opus", req["model"])
// the upstream returns its own model name
w.Header().Set("Content-Type", "application/json")
-w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
+_, err = w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
+require.NoError(t, err)
})
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
reqBody := map[string]any{
"model": "anthropic_p/claude-3-opus", // the client sends the unified ID
"max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}},
}
@@ -352,7 +364,8 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -393,7 +406,8 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
"data: [DONE]\n\n", "data: [DONE]\n\n",
} }
for _, e := range events { for _, e := range events {
w.Write([]byte(e)) _, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@@ -447,11 +461,13 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err)
var anthropicResp map[string]any
-json.Unmarshal(anthropicBody, &anthropicResp)
+require.NoError(t, json.Unmarshal(anthropicBody, &anthropicResp))
-data := anthropicResp["data"].([]any)
+data, okd := anthropicResp["data"].([]any)
+require.True(t, okd)
assert.Len(t, data, 2)
-first := data[0].(map[string]any)
+first, okf := data[0].(map[string]any)
+require.True(t, okf)
assert.Equal(t, "gpt-4", first["id"])
assert.Equal(t, "model", first["type"])
@@ -466,11 +482,12 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err)
var openaiResp map[string]any
-json.Unmarshal(openaiBody, &err)
-json.Unmarshal(openaiBody, &openaiResp)
-oaiData := openaiResp["data"].([]any)
+require.NoError(t, json.Unmarshal(openaiBody, &openaiResp))
+oaiData, oki := openaiResp["data"].([]any)
+require.True(t, oki)
assert.Len(t, oaiData, 1)
-firstOai := oaiData[0].(map[string]any)
+firstOai, okf2 := oaiData[0].(map[string]any)
+require.True(t, okf2)
assert.Equal(t, "claude-3-opus", firstOai["id"])
}
@@ -537,7 +554,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
require.Equal(t, 201, w.Code)
var created map[string]any
-json.Unmarshal(w.Body.Bytes(), &created)
+require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "sk-test", created["api_key"])
// fetching the provider should include protocol
@@ -547,7 +564,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
assert.Equal(t, 200, w.Code)
var fetched map[string]any
-json.Unmarshal(w.Body.Bytes(), &fetched)
+require.NoError(t, json.Unmarshal(w.Body.Bytes(), &fetched))
assert.Equal(t, "anthropic", fetched["protocol"])
}
@@ -570,11 +587,13 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
require.Equal(t, 201, w.Code)
var created map[string]any
-json.Unmarshal(w.Body.Bytes(), &created)
+require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "openai", created["protocol"])
}
// Suppress unused imports
-var _ = fmt.Sprintf
-var _ = strings.Contains
-var _ = time.Second
+var (
+_ = fmt.Sprintf
+_ = strings.Contains
+_ = time.Second
+)


@@ -12,19 +12,20 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
openaiConv "nex/backend/internal/conversion/openai"
) )
func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
@@ -33,7 +34,8 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
-w.Write([]byte(`{"error":"not mocked"}`))
+_, err := w.Write([]byte(`{"error":"not mocked"}`))
+require.NoError(t, err)
}))
db := setupTestDB(t)
@@ -115,11 +117,12 @@ func parseSSEEvents(body string) []map[string]string {
var currentEvent, currentData string
for scanner.Scan() {
line := scanner.Text()
-if strings.HasPrefix(line, "event: ") {
+switch {
+case strings.HasPrefix(line, "event: "):
currentEvent = strings.TrimPrefix(line, "event: ")
-} else if strings.HasPrefix(line, "data: ") {
+case strings.HasPrefix(line, "data: "):
currentData = strings.TrimPrefix(line, "data: ")
-} else if line == "" && (currentEvent != "" || currentData != "") {
+case line == "" && (currentEvent != "" || currentData != ""):
events = append(events, map[string]string{
"event": currentEvent,
"data": currentData,
@@ -157,21 +160,21 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/chat/completions", req.URL.Path)
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-001",
"object": "chat.completion",
"created": 1700000000,
"model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "你好我是AI助手。"},
"finish_reason": "stop",
"logprobs": nil,
}},
"usage": map[string]any{
"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25,
},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -210,21 +213,23 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-body, _ := io.ReadAll(req.Body)
+body, err := io.ReadAll(req.Body)
+require.NoError(t, err)
var reqBody map[string]any
-json.Unmarshal(body, &reqBody)
+require.NoError(t, json.Unmarshal(body, &reqBody))
-msgs := reqBody["messages"].([]any)
+msgs, ok := reqBody["messages"].([]any)
+require.True(t, ok)
assert.GreaterOrEqual(t, len(msgs), 3)
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"},
"finish_reason": "stop", "logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -252,7 +257,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
@@ -272,7 +277,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98}, "usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -286,9 +291,9 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"function": map[string]any{ "function": map[string]any{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"parameters": map[string]any{ "parameters": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"}, "required": []string{"city"},
}, },
}, },
}}, }},
@@ -319,22 +324,22 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
"finish_reason": "length",
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
"max_tokens": 30,
})
w := httptest.NewRecorder()
@@ -353,11 +358,11 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "答案是61。"},
"finish_reason": "stop",
"logprobs": nil,
}},
@@ -365,12 +370,12 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
"prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83, "prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83,
"completion_tokens_details": map[string]any{"reasoning_tokens": 20}, "completion_tokens_details": map[string]any{"reasoning_tokens": 20},
}, },
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/o3", "model": "openai_p/o3",
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}}, "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -393,12 +398,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": nil,
"refusal": "抱歉,我无法提供涉及危险活动的信息。",
},
@@ -406,12 +411,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
"logprobs": nil, "logprobs": nil,
}}, }},
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47}, "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
}) }))
}) })
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "做坏事"}}, "messages": []map[string]any{{"role": "user", "content": "做坏事"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -453,9 +458,9 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "你好"}}, "messages": []map[string]any{{"role": "user", "content": "你好"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -497,14 +502,14 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{ "tools": []map[string]any{{
"type": "function", "type": "function",
"function": map[string]any{ "function": map[string]any{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"parameters": map[string]any{ "parameters": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
}, },
}, },
@@ -546,9 +551,9 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "hi"}}, "messages": []map[string]any{{"role": "user", "content": "hi"}},
"stream": true, "stream": true,
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body)) req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -569,14 +574,14 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_001", "type": "message", "role": "assistant",
"content": []map[string]any{
{"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"},
},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -611,24 +616,25 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-body, _ := io.ReadAll(req.Body)
+body, err := io.ReadAll(req.Body)
+require.NoError(t, err)
var reqBody map[string]any
-json.Unmarshal(body, &reqBody)
+require.NoError(t, json.Unmarshal(body, &reqBody))
assert.NotNil(t, reqBody["system"])
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_003", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"system": "你是编程助手",
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
})
w := httptest.NewRecorder()
@@ -643,7 +649,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_009", "type": "message", "role": "assistant",
"content": []map[string]any{{
"type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather",
@@ -651,7 +657,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
}},
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -661,9 +667,9 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
"tools": []map[string]any{{ "tools": []map[string]any{{
"name": "get_weather", "description": "获取天气", "name": "get_weather", "description": "获取天气",
"input_schema": map[string]any{ "input_schema": map[string]any{
"type": "object", "type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}}, "properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"}, "required": []string{"city"},
}, },
}}, }},
"tool_choice": map[string]any{"type": "auto"}, "tool_choice": map[string]any{"type": "auto"},
@@ -689,7 +695,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_018", "type": "message", "role": "assistant",
"content": []map[string]any{
{"type": "thinking", "thinking": "这是一个逻辑推理问题..."},
@@ -697,7 +703,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -724,12 +730,12 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_016", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}},
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -752,18 +758,18 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_017", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}},
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"stop_sequences": []string{"5"},
})
w := httptest.NewRecorder()
@@ -781,19 +787,20 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-body, _ := io.ReadAll(req.Body)
+body, err := io.ReadAll(req.Body)
+require.NoError(t, err)
var reqBody map[string]any
-json.Unmarshal(body, &reqBody)
+require.NoError(t, json.Unmarshal(body, &reqBody))
metadata, _ := reqBody["metadata"].(map[string]any)
assert.Equal(t, "user_12345", metadata["user_id"])
w.Header().Set("Content-Type", "application/json")
-json.NewEncoder(w).Encode(map[string]any{
+require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_026", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
-})
+}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -814,21 +821,21 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
 	r, upstream := setupE2ETest(t)
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(map[string]any{
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_e2e_025", "type": "message", "role": "assistant",
 			"content": []map[string]any{{"type": "text", "text": "你好!"}},
 			"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
 			"usage": map[string]any{
 				"input_tokens": 25, "output_tokens": 5,
 				"cache_creation_input_tokens": 15, "cache_read_input_tokens": 0,
 			},
-		})
+		}))
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 		"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
 		"messages": []map[string]any{{"role": "user", "content": "你好"}},
 	})
 	w := httptest.NewRecorder()
@@ -864,7 +871,8 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
 		"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
 	}
 	for _, e := range events {
-		w.Write([]byte(e))
+		_, err := w.Write([]byte(e))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -874,7 +882,7 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 		"messages": []map[string]any{{"role": "user", "content": "你好"}},
 		"stream": true,
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -922,7 +930,7 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
 		"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
 		"messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
 		"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
 		"stream": true,
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -961,14 +969,14 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
 		json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_cross_001", "type": "message", "role": "assistant",
 			"content": []map[string]any{{"type": "text", "text": "跨协议响应"}},
 			"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
 			"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
 		})
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-model",
 		"messages": []map[string]any{{"role": "user", "content": "Hello"}},
 	})
 	w := httptest.NewRecorder()
@@ -1050,9 +1058,9 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-model",
 		"messages": []map[string]any{{"role": "user", "content": "Hello"}},
 		"stream": true,
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1092,7 +1100,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/gpt-4", "max_tokens": 1024,
 		"messages": []map[string]any{{"role": "user", "content": "Hello"}},
 		"stream": true,
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1128,7 +1136,7 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/nonexistent",
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 	})
 	w := httptest.NewRecorder()
@@ -1183,11 +1191,11 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
 				"content": nil,
 				"tool_calls": []map[string]any{
 					{
 						"id": "call_ptc_1", "type": "function",
 						"function": map[string]any{"name": "get_weather", "arguments": `{"city":"北京"}`},
 					},
 					{
 						"id": "call_ptc_2", "type": "function",
 						"function": map[string]any{"name": "get_weather", "arguments": `{"city":"上海"}`},
 					},
 				},
@@ -1201,7 +1209,7 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/gpt-4o",
 		"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
 		"tools": []map[string]any{{
 			"type": "function",
@@ -1242,10 +1250,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
 		json.NewEncoder(w).Encode(map[string]any{
 			"id": "chatcmpl-e2e-stop", "object": "chat.completion", "created": 1700000060, "model": "gpt-4o",
 			"choices": []map[string]any{{
 				"index": 0,
 				"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
 				"finish_reason": "stop",
 				"logprobs": nil,
 			}},
 			"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
 		})
@@ -1253,9 +1261,9 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/gpt-4o",
 		"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
 		"stop": []string{"5"},
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1291,7 +1299,7 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/gpt-4o",
 		"messages": []map[string]any{{"role": "user", "content": "危险内容"}},
 	})
 	w := httptest.NewRecorder()
@@ -1353,21 +1361,22 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
 func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
 	r, upstream := setupE2ETest(t)
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		body, _ := io.ReadAll(req.Body)
+		body, err := io.ReadAll(req.Body)
+		require.NoError(t, err)
 		var reqBody map[string]any
-		json.Unmarshal(body, &reqBody)
+		require.NoError(t, json.Unmarshal(body, &reqBody))
 		tc, _ := reqBody["tool_choice"].(map[string]any)
 		assert.Equal(t, "any", tc["type"])
 		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(map[string]any{
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_e2e_tca", "type": "message", "role": "assistant",
 			"content": []map[string]any{
 				{"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}},
 			},
 			"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
 			"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
-		})
+		}))
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1397,20 +1406,21 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
 func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
 	r, upstream := setupE2ETest(t)
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		body, _ := io.ReadAll(req.Body)
+		body, err := io.ReadAll(req.Body)
+		require.NoError(t, err)
 		var reqBody map[string]any
-		json.Unmarshal(body, &reqBody)
+		require.NoError(t, json.Unmarshal(body, &reqBody))
 		sys, ok := reqBody["system"].([]any)
 		require.True(t, ok, "system should be an array")
 		require.GreaterOrEqual(t, len(sys), 1)
 		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(map[string]any{
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_e2e_asys", "type": "message", "role": "assistant",
 			"content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}},
 			"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
 			"usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
-		})
+		}))
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1433,21 +1443,22 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
 func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
 	r, upstream := setupE2ETest(t)
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		body, _ := io.ReadAll(req.Body)
+		body, err := io.ReadAll(req.Body)
+		require.NoError(t, err)
 		var reqBody map[string]any
-		json.Unmarshal(body, &reqBody)
+		require.NoError(t, json.Unmarshal(body, &reqBody))
 		msgs := reqBody["messages"].([]any)
 		require.GreaterOrEqual(t, len(msgs), 3)
 		lastMsg := msgs[len(msgs)-1].(map[string]any)
 		assert.Equal(t, "user", lastMsg["role"])
 		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(map[string]any{
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_e2e_tr", "type": "message", "role": "assistant",
 			"content": []map[string]any{{"type": "text", "text": "北京当前晴天,温度25°C。"}},
 			"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
 			"usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
-		})
+		}))
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1497,7 +1508,8 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
 		"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
 	}
 	for _, e := range events {
-		w.Write([]byte(e))
+		_, err := w.Write([]byte(e))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -1559,7 +1571,7 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-model",
 		"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 		"tools": []map[string]any{{
 			"type": "function",
@@ -1634,14 +1646,14 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
 		json.NewEncoder(w).Encode(map[string]any{
 			"id": "msg_cross_stop", "type": "message", "role": "assistant",
 			"content": []map[string]any{{"type": "text", "text": "被截断的内容..."}},
 			"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
 			"usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
 		})
 	})
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-model",
 		"messages": []map[string]any{{"role": "user", "content": "长文"}},
 	})
 	w := httptest.NewRecorder()
@@ -1659,9 +1671,10 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
 func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
 	r, upstream := setupE2ETest(t)
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-		body, _ := io.ReadAll(req.Body)
+		body, err := io.ReadAll(req.Body)
+		require.NoError(t, err)
 		var reqBody map[string]any
-		json.Unmarshal(body, &reqBody)
+		require.NoError(t, json.Unmarshal(body, &reqBody))
 		msgs := reqBody["messages"].([]any)
 		require.GreaterOrEqual(t, len(msgs), 3)
 		toolMsg := msgs[2].(map[string]any)
@@ -1669,16 +1682,16 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
 		assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"])
 		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(map[string]any{
+		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
 			"id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o",
 			"choices": []map[string]any{{
 				"index": 0,
 				"message": map[string]any{"role": "assistant", "content": "北京当前晴天,温度25°C。"},
 				"finish_reason": "stop",
 				"logprobs": nil,
 			}},
 			"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
-		})
+		}))
 	})
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -1722,7 +1735,8 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
 		"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
 	}
 	for _, e := range events {
-		w.Write([]byte(e))
+		_, err := w.Write([]byte(e))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -1730,7 +1744,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-model",
 		"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 		"tools": []map[string]any{{
 			"type": "function",
@@ -1817,7 +1831,7 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
 	e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 	body, _ := json.Marshal(map[string]any{
 		"model": "openai_p/gpt-4o",
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 	})
 	w := httptest.NewRecorder()
@@ -1879,7 +1893,8 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
 		"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n",
 	}
 	for _, e := range events {
-		w.Write([]byte(e))
+		_, err := w.Write([]byte(e))
+		require.NoError(t, err)
 		flusher.Flush()
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -1889,7 +1904,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
 	body, _ := json.Marshal(map[string]any{
 		"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 		"stream": true,
 	})
 	w := httptest.NewRecorder()
 	req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1902,5 +1917,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
 	assert.Contains(t, respBody, "正常")
 }
-var _ = fmt.Sprintf
-var _ = time.Now
+var (
+	_ = fmt.Sprintf
+	_ = time.Now
+)
View File
@@ -7,16 +7,17 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
) )
func init() { func init() {
@@ -97,7 +98,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
 	r.ServeHTTP(w, req)
 	assert.Equal(t, 201, w.Code)
 	var createdModel domain.Model
-	json.Unmarshal(w.Body.Bytes(), &createdModel)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
 	assert.NotEmpty(t, createdModel.ID)

 	// 3. 列出 Provider
@@ -106,7 +107,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
 	r.ServeHTTP(w, req)
 	assert.Equal(t, 200, w.Code)
 	var providers []domain.Provider
-	json.Unmarshal(w.Body.Bytes(), &providers)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
 	assert.Len(t, providers, 1)
 	assert.Equal(t, "sk-test-key", providers[0].APIKey)
@@ -116,7 +117,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
 	r.ServeHTTP(w, req)
 	assert.Equal(t, 200, w.Code)
 	var models []domain.Model
-	json.Unmarshal(w.Body.Bytes(), &models)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
 	assert.Len(t, models, 1)
 	assert.Equal(t, "gpt-4", models[0].ModelName)
@@ -163,7 +164,7 @@ func TestAnthropic_ModelCreation(t *testing.T) {
 	r.ServeHTTP(w, req)
 	assert.Equal(t, 201, w.Code)
 	var createdModel domain.Model
-	json.Unmarshal(w.Body.Bytes(), &createdModel)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))

 	// 验证创建成功
 	w = httptest.NewRecorder()
@@ -194,9 +195,9 @@ func TestStats_RecordingAndQuery(t *testing.T) {
 	// 直接通过 repository 记录统计(模拟代理请求后的统计记录)
 	statsRepo := repository.NewStatsRepository(db)
-	statsRepo.Record("p1", "gpt-4")
-	statsRepo.Record("p1", "gpt-4")
-	statsRepo.Record("p1", "gpt-4")
+	require.NoError(t, statsRepo.Record("p1", "gpt-4"))
+	require.NoError(t, statsRepo.Record("p1", "gpt-4"))
+	require.NoError(t, statsRepo.Record("p1", "gpt-4"))

 	// 查询统计
 	w = httptest.NewRecorder()
@@ -205,7 +206,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
 	assert.Equal(t, 200, w.Code)
 	var stats []domain.UsageStats
-	json.Unmarshal(w.Body.Bytes(), &stats)
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
 	assert.Len(t, stats, 1)
 	assert.Equal(t, 3, stats[0].RequestCount)
View File
@@ -4,11 +4,11 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config"
) )
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。 // setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。
View File
@@ -3,9 +3,10 @@ package tests
 import (
 	"testing"

-	"nex/backend/internal/config"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+
+	"nex/backend/internal/config"
 )

 func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {
View File
@@ -10,9 +10,10 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,9 +10,10 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -11,9 +11,10 @@ package mocks
 import (
 	context "context"
+	reflect "reflect"
+
 	conversion "nex/backend/internal/conversion"
 	provider "nex/backend/internal/provider"
-	reflect "reflect"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,9 +10,10 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,9 +10,10 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,9 +10,10 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,10 +10,11 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
 	time "time"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
View File
@@ -10,10 +10,11 @@
 package mocks

 import (
-	domain "nex/backend/internal/domain"
 	reflect "reflect"
 	time "time"
+
+	domain "nex/backend/internal/domain"
 	gomock "go.uber.org/mock/gomock"
 )
lefthook.yml Normal file
View File
@@ -0,0 +1,5 @@
+pre-commit:
+  commands:
+    backend-lint:
+      glob: "backend/**/*.go"
+      run: cd backend && go tool golangci-lint run --new-from-rev HEAD ./...
View File
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-04-23
View File
@@ -1,130 +0,0 @@
-## Context
-The backend is built with Go 1.26.2 and already integrates `golangci-lint v1.64.8` as a tool dependency (declared in `go.mod`), invoked via `make backend-lint`. However, there is currently **no `.golangci.yml` configuration file**: lint runs with the default configuration (only a handful of linters enabled), and an embedfs module-loading error keeps lint from actually executing.
-A code audit found the following existing violations:
-- 8 occurrences of `err == sentinel` that should use `errors.Is()`
-- 7 occurrences of `_ = json.Marshal(...)` discarding the error return value
-- 13 occurrences of `zap.String("error", err.Error())` that should use `zap.Error(err)`
-- 2 occurrences of `fmt.Fprintf(os.Stderr, ...)` that should go through the logger
-- 1 occurrence of `_ = io.ReadAll(...)` discarding the error
## Goals / Non-Goals
**Goals:**
- 配置 golangci-lint将 README 中的编码规范转化为机器可检查的硬约束
- 引入 lefthook 实现 pre-commit 自动 lintAI 提交代码时自动拦截违规
- 修复存量代码中的规范违规
- 解决 embedfs 导致 lint 无法运行的阻塞问题
**Non-Goals:**
- 不引入自定义 linter 插件(开发成本过高)
- 不配置 CI pipeline lint 门禁(仅本地)
- 不改变现有错误响应策略(允许 err.Error() 暴露在 HTTP 响应中)
- 不引入 funlen lintergocyclo 已控制复杂度funlen 误报率高)
- 不引入 unparam linter项目 interface 密集unparam 误报率高)
## Decisions
### D1: Linter selection — 13 linters in five tiers
```
🔒 Hard-constraint tier (project conventions → machine-checked)
├── forbidigo   bans fmt.Print*/log.*/zap.L()/zap.S()
├── errorlint   enforces errors.Is/As; bans err == comparisons
└── errcheck    bans ignored error return values (check-blank: true)
🏗️ Quality-baseline tier
├── staticcheck the Go team's comprehensive static analyzer
├── revive      golint replacement, 8 curated rules
│   (exported, var-naming, indent-error-flow, error-strings,
│    error-return, blank-imports, context-as-argument,
│    unexported-return)
├── gocritic    100+ code-quality checks
└── gosec       security checks
🛡️ Resource-safety tier
├── bodyclose   HTTP response Body must be closed
├── noctx       HTTP requests must carry a context
└── nilerr      catches `if err != nil { return nil }` slips
📐 Formatting tier (auto-fixable)
├── gofumpt     stricter gofmt
└── goimports   import ordering (local-prefixes: nex/backend)
📊 Complexity tier
└── gocyclo     production code ≤10 / test code ≤20
```
**Alternative considered**: a custom revive rule to catch zap.String("error", err.Error()) → deferred. Fix the 13 existing occurrences, document the convention in the README, and rely on code review; invest in a custom rule only if violations keep recurring.
### D2: forbidigo policy
Forbidden:
- `fmt\.Print.*` — must use the zap logger
- `fmt\.Fprint.*` to os.Stdout/os.Stderr — must use the zap logger
- `log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)` — must use the zap logger
- `zap\.L()` — must inject *zap.Logger via DI
- `zap\.S()` — the Sugar logger is not used
**Not forbidden**:
- `zap.Logger.Fatal()` — Fatal in main() is a legitimate startup-abort pattern, and forbidigo matches by function name, so it does not intercept method calls on zap.Logger
- `fmt.Sprintf` — string formatting is a legitimate use
- `fmt.Errorf` — creating formatted errors is standard practice
### D3: Differentiated rules for test code vs production code
| Rule | Production code | Test code |
|---|---|---|
| forbidigo | fully enabled | relaxed (fmt.Sprintf etc. are fine) |
| errcheck | check-blank: true | check-blank relaxed |
| revive (exported) | enabled | excluded |
| gosec | enabled | G101/G401/G501 excluded |
| gocyclo | ≤10 | ≤20 |
Exclusions are configured via `issues.exclude-rules`, matched by the paths `*_test.go` and `tests/`.
### D4: Generated-code exclusion
The 8 mock files under `tests/mocks/` are generated by mockgen and must be excluded. Use `issues.exclude-dirs` together with `issues.exclude-generated: true` (golangci-lint auto-detects the `Code generated by` marker).
### D5: Fixing the embedfs blocker
The embedfs module uses `//go:embed assets/*` and `//go:embed frontend-dist/*`, but those directories do not exist before a build, so golangci-lint cannot load the module.
**Approach**: add `.gitkeep` files to `embedfs/assets/` and `embedfs/frontend-dist/` so the `go:embed` directives always have something to match. This is the least-invasive fix and does not affect the normal build flow (desktop-build overwrites these directories).
### D6: lefthook pre-commit configuration
```yaml
# lefthook.yml
pre-commit:
  commands:
    backend-lint:
      glob: "backend/**/*.go"
      run: cd backend && go tool golangci-lint run --new-from-rev HEAD {staged_files}
```
Key design points:
- Checks only staged files (`{staged_files}`) and only issues introduced since HEAD (`--new-from-rev HEAD`), keeping runs fast
- Triggers only when Go files change (`glob: "backend/**/*.go"`)
- Runs automatically on AI commits: if lint fails, the commit is rejected, closing the feedback loop
### D7: Strategy for fixing pre-existing violations
Fix in priority order:
1. **P0 — unblock embedfs**: create the .gitkeep files
2. **P1 — err == sentinel → errors.Is()**: 8 occurrences, spread across handlers and the client
3. **P2 — ignored error return values**: 7× json.Marshal + 1× io.ReadAll + 2× stats Record (annotated with nolint)
4. **P3 — zap.String("error", err.Error()) → zap.Error(err)**: 13 occurrences
5. **P4 — fmt.Fprintf(os.Stderr) → logger**: 2 occurrences, in cmd/desktop/
## Risks / Trade-offs
**[Lint latency hurts the commit experience]** → lefthook checks only staged Go files; incremental runs usually finish in under 5 seconds, which is acceptable. If it is still slow, a `--timeout` cap can be added.
**[lefthook is a new dependency]** → lefthook is a development-only tool and does not affect production code. It ships as a single binary and is easy to install (`go install` or a GitHub release download). Developers do need to install it once.
**[Bulk fixes might introduce new bugs]** → all fixes are mechanical substitutions (errors.Is, zap.Error, and so on) that do not change logic. Run `make test` afterwards to confirm no regressions.
**[forbidigo may flag legitimate usage]** → the allow list is configured carefully (fmt.Sprintf and fmt.Errorf are permitted), and rules will be tuned as false positives surface.

View File

@@ -1,29 +0,0 @@
## Why
As the project has grown in complexity, AI-written code frequently ignores basic coding conventions (using the designated logging tool, handling errors correctly, and so on). Prompt-level conventions are "soft constraints" and cannot reliably prevent violations. Static analysis upgrades these conventions from "agreements" to machine-checkable hard constraints that block problem code at commit time.
## What Changes
- Add a `.golangci.yml` configuration enabling 13 linters with project-specific rules
- Introduce lefthook as the Git hook manager, running lint automatically at pre-commit
- Fix pre-existing convention violations (about 31 occurrences)
- Resolve the embedfs issue that prevents golangci-lint from running
- Update README.md with the coding-convention additions
## Capabilities
### New Capabilities
- `code-lint`: static-analysis rule configuration for the backend, covering the 13 enabled linters, their settings, differentiated rules for test vs production code, and generated-code exclusions
- `pre-commit-hook`: lefthook-based pre-commit configuration that runs lint automatically on commit
### Modified Capabilities
- `module-logging`: adds the rule that zap.Error(err) is preferred over zap.String("error", err.Error())
- `error-handling`: adds the mandate to use errors.Is/As instead of direct == comparison
- `structured-logging`: documents the zap.Error(err) usage convention
## Impact
- New development dependency: lefthook (a binary tool; no impact on production code)
- Modified files: about 15 Go source files (bulk fixes), README.md, Makefile
- New files: `.golangci.yml`, `lefthook.yml`, `embedfs/assets/.gitkeep`, `embedfs/frontend-dist/.gitkeep`
- Workflow impact: git commit now triggers lint automatically; a failing lint rejects the commit

View File

@@ -1,174 +0,0 @@
# Code Lint
## Purpose
Defines the static-analysis rules for backend Go code, upgrading the coding conventions from manual agreements to machine-checkable hard constraints enforced by golangci-lint during development and at commit time.
## ADDED Requirements
### Requirement: golangci-lint configuration
The system SHALL configure golangci-lint via `.golangci.yml`, enabling 13 linters.
#### Scenario: Configuration file location
- **WHEN** configuring lint rules
- **THEN** the configuration file SHALL live at `backend/.golangci.yml`
#### Scenario: Enabled linters
- **WHEN** golangci-lint runs
- **THEN** the following linters SHALL be enabled: forbidigo, errorlint, errcheck, staticcheck, revive, gocritic, gosec, bodyclose, noctx, nilerr, gofumpt, goimports, gocyclo
### Requirement: forbidigo logging-output constraints
The system SHALL use forbidigo to ban direct output functions in production code.
#### Scenario: fmt.Print family banned
- **WHEN** production code calls fmt.Print, fmt.Println, or fmt.Printf
- **THEN** lint SHALL fail with a message directing to the zap logger
#### Scenario: fmt.Fprint to Stdout/Stderr banned
- **WHEN** production code calls fmt.Fprintf(os.Stdout, ...) or fmt.Fprintf(os.Stderr, ...)
- **THEN** lint SHALL fail with a message directing to the zap logger
#### Scenario: Standard-library log banned
- **WHEN** production code calls log.Print, log.Fatal, log.Panic, log.Printf, etc.
- **THEN** lint SHALL fail with a message directing to the zap logger
#### Scenario: zap.L() global logger banned
- **WHEN** production code calls zap.L()
- **THEN** lint SHALL fail with a message directing to a DI-injected *zap.Logger
#### Scenario: zap.S() Sugar logger banned
- **WHEN** code calls zap.S()
- **THEN** lint SHALL fail; the Sugar logger is not used
#### Scenario: fmt.Sprintf and fmt.Errorf allowed
- **WHEN** code uses fmt.Sprintf or fmt.Errorf
- **THEN** lint SHALL NOT fail
#### Scenario: Relaxed for test code
- **WHEN** test files (*_test.go) or the tests/ directory use the fmt.Print family
- **THEN** forbidigo SHALL NOT fail
### Requirement: errorlint error-comparison constraints
The system SHALL use errorlint to enforce type-safe error comparison.
#### Scenario: err == sentinel comparison banned
- **WHEN** code compares errors directly with `err == someError`
- **THEN** lint SHALL fail, requiring errors.Is()
#### Scenario: Direct type assertion banned
- **WHEN** code uses a direct type assertion `err.(SomeType)`
- **THEN** lint SHALL fail, requiring errors.As()
### Requirement: errcheck return-value checks
The system SHALL use errcheck to forbid ignoring returned errors.
#### Scenario: check-blank enabled
- **WHEN** code uses `_ = someFuncReturnsError()`
- **THEN** lint SHALL fail (unless the function is on the exclusion list)
#### Scenario: check-type-assertions enabled
- **WHEN** code uses an unchecked type assertion `v := x.(Type)`
- **THEN** lint SHALL fail
#### Scenario: fmt.Fprintf excluded
- **WHEN** code ignores the return value of fmt.Fprintf
- **THEN** errcheck SHALL NOT fail (reasonable for io.Writer usage)
#### Scenario: Relaxed for test code
- **WHEN** a test file ignores an error return value
- **THEN** errcheck's check-blank SHALL be relaxed
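A sketch of the compliant pattern (the `encodePayload` helper is hypothetical): the error from json.Marshal is propagated instead of being discarded with `_ =`, and type assertions use the two-value form.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// encodePayload propagates the json.Marshal error instead of the
// `_ = json.Marshal(...)` pattern that errcheck (check-blank: true) rejects.
func encodePayload(v any) ([]byte, error) {
	data, err := json.Marshal(v)
	if err != nil {
		return nil, fmt.Errorf("marshal payload: %w", err)
	}
	return data, nil
}

func main() {
	data, err := encodePayload(map[string]int{"tokens": 42})
	_ = data
	_ = err

	// Checked assertion: a bare `s := x.(string)` would be flagged
	// by check-type-assertions.
	var x any = "ok"
	if s, ok := x.(string); ok {
		_ = s
	}
}
```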
### Requirement: revive style rules
The system SHALL enable 8 curated revive style rules.
#### Scenario: Enabled rules
- **WHEN** revive runs
- **THEN** it SHALL enable: exported, var-naming, indent-error-flow, error-strings, error-return, blank-imports, context-as-argument, unexported-return
#### Scenario: exported excluded for tests
- **WHEN** an exported symbol in a test file lacks a doc comment
- **THEN** revive SHALL NOT fail
### Requirement: gosec security checks
The system SHALL use gosec to check for common security issues.
#### Scenario: Fully enabled for production code
- **WHEN** production code contains a security risk (hard-coded credentials, SQL injection, etc.)
- **THEN** gosec SHALL fail
#### Scenario: Selected rules excluded for tests
- **WHEN** a test file triggers G101 (hard-coded credentials) or G401/G501 (weak cryptographic algorithms)
- **THEN** gosec SHALL NOT fail
### Requirement: gocyclo cyclomatic-complexity limits
The system SHALL bound function complexity via gocyclo.
#### Scenario: Production-code threshold
- **WHEN** a production function's cyclomatic complexity exceeds 10
- **THEN** gocyclo SHALL fail
#### Scenario: Test-code threshold
- **WHEN** a test function's cyclomatic complexity exceeds 20
- **THEN** gocyclo SHALL fail
### Requirement: goimports import ordering
The system SHALL enforce consistent import grouping via goimports.
#### Scenario: Three groups
- **WHEN** imports are formatted
- **THEN** they SHALL be ordered in three groups: standard library, third-party, local packages (nex/backend)
- **THEN** local-prefixes SHALL be set to nex/backend
### Requirement: Generated-code exclusion
The system SHALL exclude auto-generated code from lint checks.
#### Scenario: mocks directory excluded
- **WHEN** lint scans the tests/mocks/ directory
- **THEN** the directory SHALL be excluded (mockgen-generated code)
#### Scenario: Code generated marker auto-detected
- **WHEN** a file contains a `// Code generated by` marker
- **THEN** golangci-lint SHALL exclude it automatically
### Requirement: embedfs compilation compatibility
The system SHALL ensure golangci-lint can load the embedfs module.
#### Scenario: Empty-directory placeholders
- **WHEN** the embedfs module's assets/ and frontend-dist/ directories do not exist
- **THEN** .gitkeep files SHALL ensure the directories exist
- **THEN** the go:embed directives SHALL match successfully

View File

@@ -1,45 +0,0 @@
# Error Handling — Delta
## MODIFIED Requirements
### Requirement: Type-safe error checks
The system SHALL check error types in a type-safe way, enforced by lint.
#### Scenario: Database errors
- **WHEN** checking for a database unique-constraint error
- **THEN** SHALL use errors.Is(err, gorm.ErrDuplicatedKey)
- **THEN** SHALL NOT match on the err.Error() string
#### Scenario: Network errors
- **WHEN** checking for network errors
- **THEN** SHALL use errors.As with a net.Error target to detect network errors
- **THEN** SHALL use errors.As with a *net.OpError target to detect operation errors
- **THEN** SHALL use errors.Is(opErr.Err, syscall.ECONNRESET) to detect connection resets
- **THEN** SHALL NOT classify errors by string matching
#### Scenario: Error-chain checks
- **WHEN** checking for a specific error in a chain
- **THEN** SHALL use errors.Is for chain-aware comparison
- **THEN** SHALL use errors.As to extract a specific error type
#### Scenario: Lint blocks direct comparison
- **WHEN** code compares with `err == someError`
- **THEN** errorlint SHALL detect it and fail
- **THEN** the code SHALL switch to errors.Is()
#### Scenario: Lint blocks direct type assertions
- **WHEN** code uses a direct type assertion `err.(SomeType)`
- **THEN** errorlint SHALL detect it and fail
- **THEN** the code SHALL switch to errors.As()
#### Scenario: Lint blocks ignored error returns
- **WHEN** code discards an error with `_ = funcReturnsError()`
- **THEN** errcheck SHALL detect it and fail
- **THEN** the code SHALL handle the error, or add a //nolint:errcheck comment only when ignoring it is intentional

View File

@@ -1,32 +0,0 @@
# Module Logging — Delta
## MODIFIED Requirements
### Requirement: No global logger
The system SHALL forbid global loggers in business code, enforced by lint.
#### Scenario: Remove zap.L() calls
- **WHEN** refactoring existing code
- **THEN** SHALL remove all `zap.L()` calls
- **THEN** SHALL inject the logger through constructors
- **THEN** `zap.L()` or `zap.NewNop()` is allowed only in test code
#### Scenario: Remove the zap.L() fallback
- **WHEN** a constructor's logger argument is nil
- **THEN** SHALL NOT fall back to `zap.L()` as a default
- **THEN** callers SHALL be required to pass a valid logger
#### Scenario: Lint blocks zap.L()
- **WHEN** production code introduces a new `zap.L()` call
- **THEN** forbidigo SHALL detect it and fail
- **THEN** the git commit SHALL be rejected
#### Scenario: No direct fmt/os.Stderr output
- **WHEN** production code writes directly via fmt.Print*, fmt.Fprintf(os.Stderr, ...), etc.
- **THEN** forbidigo SHALL detect it and fail
- **THEN** the injected zap logger SHALL be used instead

View File

@@ -1,54 +0,0 @@
# Pre-commit Hook
## Purpose
Defines the lefthook-based pre-commit hook configuration that runs lint automatically on git commit and blocks violating code from being committed.
## ADDED Requirements
### Requirement: lefthook configuration
The system SHALL configure the pre-commit hook via `lefthook.yml`.
#### Scenario: Configuration file location
- **WHEN** configuring lefthook
- **THEN** the configuration file SHALL live at the repository root as `lefthook.yml`
#### Scenario: Hook installation
- **WHEN** a developer clones the project for the first time
- **THEN** running `lefthook install` SHALL install the git hooks
- **THEN** the hooks SHALL be registered under .git/hooks/
### Requirement: Lint triggered by Go file changes
The system SHALL run golangci-lint automatically when Go files change.
#### Scenario: Go file changes detected
- **WHEN** a git commit includes changes to backend/**/*.go files
- **THEN** golangci-lint SHALL run automatically
#### Scenario: Incremental checks
- **WHEN** lint runs
- **THEN** it SHALL report only issues introduced since HEAD (using --new-from-rev HEAD)
- **THEN** it SHALL NOT scan the whole codebase
#### Scenario: Lint passes
- **WHEN** golangci-lint passes
- **THEN** the commit SHALL complete normally
#### Scenario: Lint fails
- **WHEN** golangci-lint finds violations
- **THEN** the commit SHALL be rejected
- **THEN** the specific violations and fix hints SHALL be shown
#### Scenario: No Go file changes
- **WHEN** a git commit contains no Go file changes
- **THEN** golangci-lint SHALL NOT run
- **THEN** the commit SHALL complete normally

View File

@@ -1,30 +0,0 @@
# Structured Logging — Delta
## MODIFIED Requirements
### Requirement: Field standardization
The system SHALL use standardized field definitions, with the error-field convention enforced by lint.
#### Scenario: Standard field constants
- **WHEN** logging fields
- **THEN** SHALL use the constants defined in `pkg/logger/field.go`
- **THEN** field names SHALL include: `request_id`, `provider_id`, `model_name`, `method`, `path`, `status`, `latency`
#### Scenario: Unified error field
- **WHEN** logging an error
- **THEN** SHALL use `zap.Error(err)`
- **THEN** SHALL NOT use `zap.String("error", err.Error())`
#### Scenario: Lint reinforces the error-field constraint
- **WHEN** existing code logs errors via `zap.String("error", err.Error())`
- **THEN** it SHALL be changed to `zap.Error(err)`
#### Scenario: Field constructors
- **WHEN** building log fields
- **THEN** SHALL prefer the helper functions provided by `pkg/logger`
- **THEN** the helpers SHALL return `zap.Field`

View File

@@ -1,42 +0,0 @@
## 1. Infrastructure
- [x] 1.1 Create embedfs/assets/.gitkeep and embedfs/frontend-dist/.gitkeep to unblock embedfs compilation
- [x] 1.2 Create backend/.golangci.yml enabling the 13 linters with all rules configured (forbidigo, errorlint, errcheck, staticcheck, revive, gocritic, gosec, bodyclose, noctx, nilerr, gofumpt, goimports, gocyclo)
- [x] 1.3 Configure the differentiated test-code rules in .golangci.yml (exclude-rules for *_test.go and tests/)
- [x] 1.4 Configure generated-code exclusions in .golangci.yml (exclude-dirs: tests/mocks, exclude-generated: true)
- [ ] 1.5 Run make backend-lint to verify the configuration executes normally (no embedfs errors)
- [x] 1.6 Create lefthook.yml with a pre-commit hook that checks staged Go files only
- [ ] 1.7 Run lefthook install and verify the hook takes effect
## 2. Bulk fixes — error comparison
- [x] 2.1 Fix 4× err == sentinel → errors.Is() in internal/handler/model_handler.go
- [x] 2.2 Fix 4× err == sentinel → errors.Is() in internal/handler/provider_handler.go
- [x] 2.3 Fix internal/provider/client.go:223 err == io.EOF → errors.Is(err, io.EOF)
## 3. Bulk fixes — ignored error return values
- [x] 3.1 Fix 3× _ = json.Marshal in internal/conversion/openai/adapter.go → handle the error
- [x] 3.2 Fix 2× _ = json.Marshal in internal/conversion/anthropic/adapter.go → handle the error
- [x] 3.3 Fix 1× _ = json.Marshal in internal/conversion/anthropic/decoder.go → handle the error
- [x] 3.4 Fix internal/conversion/engine.go:394 _ = json.Marshal → handle the error (fallback path)
- [x] 3.5 Fix internal/provider/client.go:144 _ = io.ReadAll → handle the error
- [x] 3.6 Add //nolint:errcheck comments to 2× _ = statsService.Record in internal/handler/proxy_handler.go (goroutine fire-and-forget pattern)
## 4. Bulk fixes — log fields
- [x] 4.1 Fix zap.String("error", err.Error()) → zap.Error(err) in internal/handler/proxy_handler.go (about 6 occurrences)
- [x] 4.2 Fix internal/provider/client.go:187 zap.String("error", err.Error()) → zap.Error(err)
- [x] 4.3 Fix zap.String("error", err.Error()) → zap.Error(err) in internal/conversion/engine.go (about 6 occurrences)
## 5. Bulk fixes — desktop logging
- [x] 5.1 Fix 2× fmt.Fprintf(os.Stderr, ...) in cmd/desktop/dialog_linux.go → use the zap logger
## 6. Verification and docs
- [ ] 6.1 Run make backend-lint and confirm all linters pass
- [ ] 6.2 Run make backend-test and confirm all tests pass
- [x] 6.3 Update the coding-convention section of backend/README.md: zap.Error(err) is preferred over zap.String("error", err.Error())
- [x] 6.4 Update the coding-convention section of backend/README.md: errors.Is/As is mandatory instead of == comparison
- [x] 6.5 Update README.md with lefthook installation instructions (run lefthook install after the first clone)