1
0

chore: 合并 dev-code-backend-format 到 master

This commit is contained in:
2026-04-24 18:20:12 +08:00
91 changed files with 1322 additions and 824 deletions

1
.gitignore vendored
View File

@@ -401,6 +401,7 @@ cython_debug/
# Custom
.claude
.opencode
.codex
openspec/changes/archive
temp
.agents

View File

@@ -1,11 +1,11 @@
.PHONY: all dev build test lint clean \
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \
backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-clean backend-deps backend-generate \
backend-db-up backend-db-down backend-db-status backend-db-create \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
desktop-dev desktop-test desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
desktop-prepare-frontend desktop-prepare-embedfs
# ============================================
@@ -19,7 +19,7 @@ dev:
build: backend-build frontend-build
@echo "✅ Build complete"
test: backend-test frontend-test
test: backend-test desktop-test frontend-test
@echo "✅ All tests passed"
lint: backend-lint frontend-lint
@@ -41,6 +41,9 @@ backend-dev:
cd backend && go run ./cmd/server
backend-test:
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
backend-test-all:
cd backend && go test ./... -v
backend-test-unit:
@@ -179,6 +182,9 @@ desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🖥️ Starting desktop app in dev mode..."
cd backend && go run ./cmd/desktop
desktop-test:
cd backend && go test ./cmd/desktop/... -v
desktop-package-mac:
./scripts/build/package-macos.sh

View File

@@ -294,6 +294,9 @@ make frontend-test-coverage # 前端覆盖率
## 开发
```bash
# 首次克隆后安装 Git hooks
lefthook install
# 顶层便捷命令
make dev # 启动开发环境(并行启动后端和前端)
make build # 构建所有产物

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

View File

@@ -609,6 +609,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -6,19 +6,25 @@ import (
"fmt"
"os/exec"
"strings"
"go.uber.org/zap"
)
func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title))
exec.Command("osascript", "-e", script).Run()
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
escapeAppleScript(message))
exec.Command("osascript", "-e", script).Run()
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示关于对话框失败", zap.Error(err))
}
}
func escapeAppleScript(s string) string {

View File

@@ -4,7 +4,6 @@ package main
import (
"fmt"
"os"
"os/exec"
"sync"
)
@@ -63,7 +62,7 @@ func showError(title, message string) {
exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run()
default:
fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message)
dialogLogger().Error("无法显示错误对话框")
}
}
@@ -83,6 +82,6 @@ func showAbout() {
exec.Command("xmessage", "-center",
fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
default:
fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message)
dialogLogger().Info("关于 Nex Gateway")
}
}

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -13,10 +13,7 @@ import (
"strings"
"time"
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/embedfs"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
@@ -28,9 +25,13 @@ import (
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger"
"nex/embedfs"
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
var (
@@ -51,12 +52,16 @@ func main() {
showError("Nex Gateway", "已有 Nex 实例运行")
os.Exit(1)
}
defer singleLock.Unlock()
defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err))
showError("Nex Gateway", err.Error())
os.Exit(1)
return
}
cfg, err := config.LoadConfig()
@@ -75,7 +80,11 @@ func main() {
if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer zapLogger.Sync()
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
@@ -144,14 +153,14 @@ func main() {
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
zapLogger.Fatal("服务器启动失败", zap.Error(err))
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
zapLogger.Warn("无法打开浏览器", zap.Error(err))
}
}()
@@ -193,7 +202,7 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
getContentType := func(path string) string {
@@ -266,7 +275,7 @@ func setupSystray(port int) {
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
}
if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
zapLogger.Error("无法加载托盘图标", zap.Error(err))
}
systray.SetIcon(icon)
systray.SetTitle("Nex Gateway")
@@ -287,7 +296,9 @@ func setupSystray(port int) {
for {
select {
case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port))
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mAbout.ClickedCh:
showAbout()
case <-mQuit.ClickedCh:
@@ -308,7 +319,9 @@ func doShutdown() {
if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Shutdown(ctx)
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
}
if shutdownCancel != nil {
@@ -346,8 +359,8 @@ func (s *SingletonLock) Lock() error {
return nil
}
func (s *SingletonLock) Unlock() {
s.flock.Unlock()
func (s *SingletonLock) Unlock() error {
return s.flock.Unlock()
}
func openBrowser(url string) error {

View File

@@ -21,7 +21,7 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827")
listener, err := net.Listen("tcp", "127.0.0.1:19827")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
@@ -47,13 +47,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828
listener, err := net.Listen("tcp", ":19828")
listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
server := &http.Server{}
go server.Serve(listener)
server := &http.Server{ReadHeaderTimeout: time.Second}
defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
}
defer lock.Unlock()
defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
}
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
defer lock1.Unlock()
defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath)
err := lock2.Lock()
if err == nil {
lock2.Unlock()
if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil")
}
}
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
lock1.Unlock()
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err)
}
lock2.Unlock()
if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock()
if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
}

View File

@@ -6,9 +6,9 @@ import (
"strings"
"testing"
"github.com/gin-gonic/gin"
"nex/embedfs"
"github.com/gin-gonic/gin"
)
func TestSetupStaticFiles(t *testing.T) {

View File

@@ -44,7 +44,11 @@ func main() {
if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer zapLogger.Sync()
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)

View File

@@ -1,6 +1,7 @@
package config
import (
"errors"
"fmt"
"os"
"path/filepath"
@@ -58,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values
func DefaultConfig() *Config {
// Use home dir for default paths
homeDir, _ := os.UserHomeDir()
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
return &Config{
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil {
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir()
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826)
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
}
// setupEnv 绑定环境变量
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
// 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil {
// 忽略写入错误(可能目录已存在等)
writeErr := v.SafeWriteConfigAs(configPath)
if writeErr == nil {
return nil
}
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
}
return nil
}
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:])
if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0600)
return os.WriteFile(configPath, data, 0o600)
}
// Validate validates the config

View File

@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644)
err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err)
// 加载配置

View File

@@ -6,15 +6,15 @@ import (
// Provider 供应商模型
type Provider struct {
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string {
return "usage_stats"
}

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message,
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode
}
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType {
case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
default:
return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import (
"encoding/json"
"errors"
"testing"
"nex/backend/internal/conversion"
@@ -52,10 +53,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
nativePath string
name string
nativePath string
interfaceType conversion.InterfaceType
expected string
expected string
}{
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
@@ -102,9 +103,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
name string
interfaceType conversion.InterfaceType
expected bool
expected bool
}{
{"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true},
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages {
decoded := decodeMessage(msg)
decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...)
}
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
}
// decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage {
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role {
case "user":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock
for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
}
return result
return result, nil
case "assistant":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock(""))
}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
}
return nil
return nil, nil
}
// decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock {
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) {
case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any:
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m)
block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil {
blocks = append(blocks, *block)
}
}
}
if len(blocks) > 0 {
return blocks
return blocks, nil
}
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
}
}
// decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
t, _ := m["type"].(string)
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t {
case "text":
text, _ := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text}
text, ok := m["text"].(string)
if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use":
id, _ := m["id"].(string)
name, _ := m["name"].(string)
input, _ := json.Marshal(m["input"])
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
id, ok := m["id"].(string)
if !ok {
id = ""
}
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result":
toolUseID, _ := m["tool_use_id"].(string)
toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false
if ie, ok := m["is_error"].(bool); ok {
isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string:
content = json.RawMessage(fmt.Sprintf("%q", cv))
default:
content, _ = json.Marshal(cv)
encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
}
} else {
content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID,
Content: content,
IsError: &isErr,
}
}, nil
case "thinking":
thinking, _ := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
thinking, ok := m["thinking"].(string)
if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking":
// 丢弃
return nil
return nil, nil
}
return nil
return nil, nil
}
// decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "auto":
return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any":
return canonical.NewToolChoiceAny()
case "tool":
name, _ := v["name"].(string)
name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
}

View File

@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
result = append(result, m)
case "tool_result":
m := map[string]any{
"type": "tool_result",
"type": "tool_result",
"tool_use_id": b.ToolUseID,
}
if b.Content != nil {
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
}
result := map[string]any{
"id": resp.ID,
"type": "message",
"role": "assistant",
"model": resp.Model,
"content": blocks,
"id": resp.ID,
"type": "message",
"role": "assistant",
"model": resp.Model,
"content": blocks,
"stop_reason": sr,
"stop_sequence": nil,
"usage": usage,

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
}
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息
foundToolResult := false
for _, m := range msgs {
msgMap := m.(map[string]any)
msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any)
if ok {
for _, c := range content {
block := c.(map[string]any)
block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" {
foundToolResult = true
}
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"])
}
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"])
}
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any)
data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1)
model := data[0].(map[string]any)
model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any)
userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any)
content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2)
}
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning)
}
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk
if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...)
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil
}
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") {
switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") {
case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
}
eventType = ""
eventData = ""
} else if line == "" {
// SSE 事件分隔符
case line == "":
continue
}
}
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
// processContentBlockStart 处理内容块开始事件
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
var raw struct {
Index int `json:"index"`
Index int `json:"index"`
ContentBlock struct {
Type string `json:"type"`
Text string `json:"text"`

View File

@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
checkValue string
}{
{
name: "text_delta",
deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
name: "text_delta",
deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
checkField: "text",
checkValue: "你好",
},
{
name: "input_json_delta",
deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
name: "input_json_delta",
deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
checkField: "partial_json",
checkValue: "{\"key\":",
},
{
name: "thinking_delta",
deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
name: "thinking_delta",
deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
checkField: "thinking",
checkValue: "思考中",
},
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
payload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": tt.deltaData,
"delta": tt.deltaData,
}
raw := makeAnthropicEvent("content_block_delta", payload)
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
"type": "content_block_start",
"index": 3,
"content_block": map[string]any{
"type": "web_search_tool_result",
"type": "web_search_tool_result",
"tool_use_id": "search_1",
},
}
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{
"type": "citations_delta",
"citation": map[string]any{"title": "ref1"},
"type": "citations_delta",
"citation": map[string]any{"title": "ref1"},
},
}
raw := makeAnthropicEvent("content_block_delta", payload)
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
},
}
deltaPayload1 := map[string]any{
"type": "message_delta",
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 25},
}
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
assert.Equal(t, 25, events[0].Usage.OutputTokens)
deltaPayload2 := map[string]any{
"type": "message_delta",
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 30},
}

View File

@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"])
}
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"])
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"])
}
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break
}
}
delta := payload["delta"].(map[string]any)
delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"])
}
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break
}
}
u := payload["usage"].(map[string]any)
u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"])
}

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
}
func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil)
blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text)
}
func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello")
blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text)
}
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
r, ok := result.(map[string]any)
require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
})
}
}
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any)
tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1)
tool := tools[0].(map[string]any)
tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any)
tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"])
}
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cacheRead,
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cacheRead,
CacheCreationTokens: &cacheCreation,
},
}
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -6,22 +6,22 @@ import (
// MessagesRequest Anthropic Messages 请求
type MessagesRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"`
}
// RequestMetadata 请求元数据
@@ -122,8 +122,8 @@ type ContentBlock struct {
// ResponseUsage 响应用量
type ResponseUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
}

View File

@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
// EmbeddingData 嵌入数据项
type EmbeddingData struct {
Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
}
// EmbeddingUsage 嵌入用量

View File

@@ -18,17 +18,17 @@ const (
type DeltaType string
const (
DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta"
DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta"
)
// StreamDelta 流式增量联合体
type StreamDelta struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
}
// StreamContentBlock 流式内容块联合体
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
Message *StreamMessage `json:"message,omitempty"`
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
Index *int `json:"index,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"`
// MessageDeltaEvent
StopReason *StopReason `json:"stop_reason,omitempty"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage *CanonicalUsage `json:"usage,omitempty"`
// ErrorEvent

View File

@@ -40,8 +40,8 @@ type ContentBlock struct {
Text string `json:"text,omitempty"`
// ToolUseBlock
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// ToolResultBlock
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
// OutputFormat 输出格式联合体
type OutputFormat struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
Type string `json:"type"`
Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// CanonicalRequest 规范请求
type CanonicalRequest struct {
Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Messages []CanonicalMessage `json:"messages"`
Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
}
// CanonicalUsage 规范用量
type CanonicalUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
}
// CanonicalResponse 规范响应
type CanonicalResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"`
ID string `json:"id"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"`
}
// GetSystemString 获取系统消息字符串

View File

@@ -10,9 +10,9 @@ import (
func TestGetSystemString(t *testing.T) {
tests := []struct {
name string
system any
want string
name string
system any
want string
}{
{"string", "hello", "hello"},
{"nil", nil, ""},
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
func TestCanonicalResponse_RoundTrip(t *testing.T) {
sr := StopReasonEndTurn
resp := &CanonicalResponse{
ID: "resp-1",
Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")},
ID: "resp-1",
Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")},
StopReason: &sr,
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
}
data, err := json.Marshal(resp)

View File

@@ -114,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil {
@@ -122,7 +122,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl,
URL: provider.BaseURL + providerURL,
Method: spec.Method,
Headers: providerHeaders,
Body: providerBody,
@@ -134,24 +134,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return &spec, nil
}
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if rewriteErr != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(err),
zap.Error(rewriteErr),
zap.String("interface", string(interfaceType)))
return &spec, nil
} else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
}
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
return &spec, nil
}
@@ -182,11 +179,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return NewPassthroughStreamConverter(), nil
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
return NewPassthroughStreamConverter(), nil
}
@@ -201,9 +197,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
}
ctx := ConversionContext{
ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat,
Timestamp: time.Now(),
ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat,
Timestamp: time.Now(),
}
return NewCanonicalStreamConverterWithMiddleware(
@@ -306,7 +302,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelsResponse(models)
@@ -320,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
return encoded, nil
@@ -334,7 +330,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -343,7 +339,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
if modelOverride != "" {
@@ -355,21 +351,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeRerankRequest(req, provider)
}
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body)
if err != nil {
return body, nil
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if decodeErr == nil {
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
}
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
return body, nil
}
// DetectInterfaceType 检测接口类型
@@ -391,8 +388,12 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error",
},
}
body, _ := json.Marshal(fallback)
return body, 500, nil
body, marshalErr := json.Marshal(fallback)
if marshalErr == nil {
return body, 500, nil
}
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
}
body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil

View File

@@ -38,8 +38,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
}
}
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
@@ -190,14 +190,16 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器
type noopStreamEncoder struct{}
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
// ============ 测试用例 ============
@@ -615,6 +617,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
}
return nil
}
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil {
return d.flushFn()
@@ -634,6 +637,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
}
return nil
}
func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil {
return e.flushFn()

View File

@@ -6,17 +6,17 @@ import "fmt"
type ErrorCode string
const (
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
)
// ConversionError 协议转换错误

View File

@@ -4,10 +4,10 @@ package conversion
type InterfaceType string
const (
InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
)

View File

@@ -138,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code),
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode
}
@@ -248,7 +251,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
@@ -282,12 +289,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
}
return json.Marshal(m)
default:

View File

@@ -48,10 +48,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
nativePath string
name string
nativePath string
interfaceType conversion.InterfaceType
expected string
expected string
}{
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
@@ -92,9 +92,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
name string
interfaceType conversion.InterfaceType
expected bool
expected bool
}{
{"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true},

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
// contentPart 内容部分
type contentPart struct {
Type string
Text string
Refusal string
Type string
Text string
Refusal string
}
// decodeContentParts 解码内容部分
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart
for _, item := range parts {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text})
case "refusal":
refusal, _ := m["refusal"].(string)
refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
}
}
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "function":
if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string)
name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "custom":
if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string)
name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string)
mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" {
return canonical.NewToolChoiceAny()
}
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
case map[string]any:
if name, ok := v["name"].(string); ok {
req.ToolChoice = map[string]any{
"type": "function",
"type": "function",
"function": map[string]any{"name": name},
}
}

View File

@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
"object": "list",
"data": data,
"model": resp.Model,
"usage": resp.Usage,
"usage": resp.Usage,
})
}

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"])
}
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any)
tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"])
}
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
func TestEncodeResponse_Basic(t *testing.T) {
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "resp-1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
ID: "resp-1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
}
body, err := encodeResponse(resp)
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
msg := choice["message"].(map[string]any)
choices, ok := result["choices"].([]any)
require.True(t, ok)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"])
}
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
sr := canonical.StopReasonToolUse
input := json.RawMessage(`{"q":"test"}`)
resp := &canonical.CanonicalResponse{
ID: "resp-2",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
ID: "resp-2",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
StopReason: &sr,
}
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okc := result["choices"].([]any)
require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"])
data := result["data"].([]any)
data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2)
}
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"])
}

View File

@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"object": "chat.completion.chunk",
"model": "gpt-4",
"id": "chatcmpl-1",
"object": "chat.completion.chunk",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
idx := 0
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
// 先发送一个文本 chunk
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
// 连续两个 refusal delta chunk
for _, text := range []string{"拒绝", "原因"} {
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx0 := 0
chunk1 := map[string]any{
"id": "chatcmpl-mt",
"model": "gpt-4",
"id": "chatcmpl-mt",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx1 := 1
chunk2 := map[string]any{
"id": "chatcmpl-mt",
"model": "gpt-4",
"id": "chatcmpl-mt",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
d := NewStreamDecoder()
chunk1 := map[string]any{
"id": "chatcmpl-multi",
"model": "gpt-4",
"id": "chatcmpl-multi",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
},
}
chunk2 := map[string]any{
"id": "chatcmpl-multi",
"model": "gpt-4",
"id": "chatcmpl-multi",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-utf8",
"model": "gpt-4",
"id": "chatcmpl-utf8",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
idx := 0
chunk1 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
},
}
chunk2 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,

View File

@@ -10,9 +10,9 @@ import (
// StreamEncoder OpenAI 流式编码器
type StreamEncoder struct {
bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int
nextToolCallIndex int
bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int
nextToolCallIndex int
}
// NewStreamEncoder 创建 OpenAI 流式编码器
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
chunk := map[string]any{
"choices": []map[string]any{{
"index": 0,
"delta": delta,
"index": 0,
"delta": delta,
}},
}
return e.marshalChunk(chunk)

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any)
choices, okch := payload["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"])
}

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any)
results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1)
}
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
reasoning := 20
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"])
})
}

View File

@@ -4,42 +4,42 @@ import "encoding/json"
// ChatCompletionRequest OpenAI Chat Completion 请求
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
// 已废弃字段
Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
}
// Message OpenAI 消息
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Role string `json:"role"`
Content any `json:"content"`
Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
// 已废弃
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
@@ -88,8 +88,8 @@ type FunctionDef struct {
// ResponseFormat OpenAI 响应格式
type ResponseFormat struct {
Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
}
// JSONSchemaDef JSON Schema 定义
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
// Choice OpenAI 选择项
type Choice struct {
Index int `json:"index"`
Index int `json:"index"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
FinishReason *string `json:"finish_reason"`
@@ -127,10 +127,10 @@ type Choice struct {
// Usage OpenAI 用量
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
}

View File

@@ -18,7 +18,7 @@ import (
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
moduleLogger := pkglogger.WithModule(zapLogger, "database")
db, err := initDB(cfg, moduleLogger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
return gorm.Open(mysql.Open(dsn), gormConfig)
default:
dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil {
if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
if zapLogger != nil {
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
zap.String("dir", migrationsSubDir))
}
goose.SetDialect(gooseDialect)
if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}

View File

@@ -4,11 +4,11 @@ import (
"path/filepath"
"testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/config"
)
func TestInit_SQLite(t *testing.T) {

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest"
"testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
)
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"base_url": "https://api.test.com",
})
w := httptest.NewRecorder()

View File

@@ -9,23 +9,22 @@ import (
"strings"
"testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -20,7 +20,6 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
if id, ok := requestID.(string); ok {
requestIDStr = id
}
logger.Info("请求开始",
pkglogger.Method(c.Request.Method),
pkglogger.Path(path),

View File

@@ -4,13 +4,13 @@ import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ModelHandler 模型管理处理器
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model)
if err != nil {
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if err == appErrors.ErrDuplicateModel {
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})

View File

@@ -4,13 +4,13 @@ import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider)
if err != nil {
if err == appErrors.ErrInvalidProviderID {
if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})

View File

@@ -3,30 +3,32 @@ package handler
import (
"bufio"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
"nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
// ProxyHandler 统一代理处理器
type ProxyHandler struct {
engine *conversion.ConversionEngine
client provider.ProviderClient
routingService service.RoutingService
engine *conversion.ConversionEngine
client provider.ProviderClient
routingService service.RoutingService
providerService service.ProviderService
statsService service.StatsService
logger *zap.Logger
statsService service.StatsService
logger *zap.Logger
}
// NewProxyHandler 创建统一代理处理器
@@ -138,7 +140,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName, // 上游模型名,用于请求改写
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
@@ -159,7 +161,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
@@ -167,7 +169,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
@@ -175,7 +177,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
@@ -191,7 +193,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
@@ -226,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
for event := range eventChan {
if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
h.logger.Error("流读取错误", zap.Error(event.Error))
break
}
if event.Done {
// flush 转换器
chunks := streamConverter.Flush()
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
break
}
chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
break
}
}
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat {
return false
}
@@ -272,7 +290,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
return
}
@@ -294,7 +312,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return
}
@@ -342,8 +360,13 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
var convErr *conversion.ConversionError
if errors.As(err, &convErr) {
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
if encodeErr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
return
}
c.Data(statusCode, "application/json", body)
return
}

View File

@@ -8,27 +8,26 @@ import (
"net/http/httptest"
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
)
func init() {
gin.SetMode(gin.TestMode)
}
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper()
registry := conversion.NewMemoryRegistry()
@@ -844,7 +843,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok)
assert.Len(t, data, 2)
first := data[0].(map[string]interface{})
first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"])
}
@@ -918,7 +918,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{}
json.Unmarshal(spec.Body, &req)
require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{

View File

@@ -5,9 +5,9 @@ import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
)
// StatsHandler 统计处理器

View File

@@ -51,6 +51,7 @@ type Client struct {
}
// ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
@@ -141,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
cancel()
errBody, _ := io.ReadAll(resp.Body)
errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d读取错误响应失败: %w", resp.StatusCode, readErr)
}
if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
}
@@ -184,7 +188,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil {
if err != io.EOF {
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.Error(err))

View File

@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
}))
defer server.Close()
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
}))
defer server.Close()
@@ -140,12 +142,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
@@ -165,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
var dataEvents [][]byte
var doneEvents int
for event := range eventChan {
if event.Done {
switch {
case event.Done:
doneEvents++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataEvents = append(dataEvents, event.Data)
}
}
@@ -215,7 +221,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`))
_, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
}))
defer server.Close()
@@ -238,10 +245,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
}))
@@ -261,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
var dataCount int
var doneCount int
for event := range eventChan {
if event.Done {
switch {
case event.Done:
doneCount++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataCount++
}
}
@@ -279,10 +289,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
@@ -364,13 +376,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
conn.Close()
require.NoError(t, conn.Close())
}
}
}))

View File

@@ -3,10 +3,11 @@ package repository
import (
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)

View File

@@ -3,13 +3,13 @@ package repository
import (
"testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
)
func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -3,11 +3,11 @@ package repository
import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type statsRepository struct {
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
}
func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
now := time.Now()
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
stats := config.UsageStats{
ProviderID: providerID,

View File

@@ -1,11 +1,15 @@
package service
import (
"github.com/google/uuid"
appErrors "nex/backend/pkg/errors"
"errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)
type modelService struct {
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil // 未找到,不重复
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复
}
return err
}
if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身

View File

@@ -3,10 +3,10 @@ package service
import (
"strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors"
)

View File

@@ -4,10 +4,11 @@ import (
"strings"
"sync"
"go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
@@ -34,7 +35,9 @@ func NewRoutingCache(
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
}
provider, err := c.providerRepo.GetByID(id)
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
}
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
}
c.providers.Store(id, provider)
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
if model, ok := v.(*domain.Model); ok {
return model, nil
}
}
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
}
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
if model, ok := v.(*domain.Model); ok {
return model, nil
}
}
c.models.Store(key, model)
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/"
count := 0
c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) {
keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key)
count++
}

View File

@@ -5,11 +5,11 @@ import (
"sync"
"testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" {
keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++
}
return true

View File

@@ -1,9 +1,8 @@
package service
import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
)
type routingService struct {

View File

@@ -3,12 +3,12 @@ package service
import (
"testing"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
func TestProviderService_Update(t *testing.T) {
@@ -133,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
totalCount := 0
for _, r := range result {
totalCount += r["request_count"].(int)
count, ok := r["request_count"].(int)
require.True(t, ok)
totalCount += count
}
assert.Equal(t, 15, totalCount)
}

View File

@@ -5,6 +5,9 @@ import (
"testing"
"time"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -13,8 +16,6 @@ import (
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
@@ -134,7 +135,7 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
@@ -148,7 +149,7 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -160,7 +161,7 @@ func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -176,7 +177,7 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -202,7 +203,7 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent-id", map[string]interface{}{
@@ -215,7 +216,7 @@ func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -241,7 +242,7 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -259,7 +260,7 @@ func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model")
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date")
@@ -448,7 +451,7 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
@@ -474,7 +477,7 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

View File

@@ -6,9 +6,10 @@ import (
"sync/atomic"
"time"
"nex/backend/internal/repository"
"go.uber.org/zap"
"nex/backend/internal/repository"
pkglogger "nex/backend/pkg/logger"
)
@@ -67,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
var counter *int64
if v, ok := b.counters.Load(key); ok {
counter = v.(*int64)
if existing, ok := v.(*int64); ok {
counter = existing
} else {
return
}
} else {
val := int64(0)
counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded {
counter = actual.(*int64)
existing, ok := actual.(*int64)
if !ok {
return
}
counter = existing
}
}
@@ -117,13 +126,20 @@ func (b *StatsBuffer) flush() {
var entries []statEntry
b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string)
keyStr, ok := key.(string)
if !ok {
return true
}
parts := strings.Split(keyStr, "/")
if len(parts) != 3 {
return true
}
counter := value.(*int64)
counter, ok := value.(*int64)
if !ok {
return true
}
count := atomic.SwapInt64(counter, 0)
if count > 0 {
@@ -143,8 +159,17 @@ func (b *StatsBuffer) flush() {
success := 0
for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
date, err := time.Parse("2006-01-02", entry.date)
if err != nil {
b.logger.Error("解析统计日期失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.String("date", entry.date),
zap.Error(err))
continue
}
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil {
b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID),
@@ -154,8 +179,10 @@ func (b *StatsBuffer) flush() {
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok {
counter := v.(*int64)
atomic.AddInt64(counter, entry.count)
counter, ok := v.(*int64)
if ok {
atomic.AddInt64(counter, entry.count)
}
}
} else {
success++

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockStatsRepo struct {
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count += atomic.LoadInt64(counter)
counter, ok := value.(*int64)
if ok {
count += atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(3), count)
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(100), count)
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
beforeCount = atomic.LoadInt64(counter)
counter, ok := value.(*int64)
if ok {
beforeCount = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(2), beforeCount)
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
afterCount = atomic.LoadInt64(counter)
counter, ok := value.(*int64)
if ok {
afterCount = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(0), afterCount)
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(2), count)

View File

@@ -1,6 +1,7 @@
package errors
import (
stderrors "errors"
"fmt"
"net/http"
)
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
if err == nil {
return nil, false
}
var appErr *AppError
if ok := is(err, &appErr); ok {
return appErr, true
}
return nil, false
}
func is(err error, target interface{}) bool {
// 简单的类型断言
if e, ok := err.(*AppError); ok {
// 直接赋值
switch t := target.(type) {
case **AppError:
*t = e
return true
}
var appErr *AppError
if !stderrors.As(err, &appErr) {
return nil, false
}
return false
return appErr, true
}

View File

@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
func TestAsAppError(t *testing.T) {
t.Run("nil输入", func(t *testing.T) {
_, ok := AsAppError(nil)
appErr, ok := AsAppError(nil)
assert.Nil(t, appErr)
assert.False(t, ok)
})
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
})
t.Run("非AppError类型", func(t *testing.T) {
_, ok := AsAppError(errors.New("普通错误"))
appErr, ok := AsAppError(errors.New("普通错误"))
assert.Nil(t, appErr)
assert.False(t, ok)
})
}

View File

@@ -19,7 +19,7 @@ func TestNew_StdoutOnly(t *testing.T) {
func TestNew_WithFileOutput(t *testing.T) {
dir := filepath.Join(os.TempDir(), "nex-logger-test")
os.MkdirAll(dir, 0755)
require.NoError(t, os.MkdirAll(dir, 0o755))
defer os.RemoveAll(dir)
logger, err := New(Config{
@@ -81,7 +81,7 @@ func TestParseLevel(t *testing.T) {
{"info", true},
{"warn", true},
{"error", true},
{"", true}, // 默认为 info
{"", true}, // 默认为 info
{"invalid", true}, // 默认为 info
}
for _, tt := range tests {

View File

@@ -22,9 +22,9 @@ func newRotateWriter(cfg Config) *lumberjack.Logger {
return &lumberjack.Logger{
Filename: logFilePath(cfg.Path),
MaxSize: maxSize, // MB
MaxSize: maxSize, // MB
MaxBackups: maxBackups,
MaxAge: maxAge, // days
MaxAge: maxAge, // days
Compress: cfg.Compress,
}
}

View File

@@ -6,10 +6,10 @@ import (
"testing"
"time"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestLoadConfig_DefaultValues(t *testing.T) {
@@ -72,7 +72,7 @@ log:
max_age: 7
compress: false
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err)
cfg, err := config.LoadConfigFromPath(configPath)
@@ -103,7 +103,7 @@ server:
log:
level: warn
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
require.NoError(t, err)
t.Setenv("NEX_SERVER_PORT", "9000")
@@ -147,7 +147,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
}
defer func() {
if originalConfig != nil {
_ = os.WriteFile(configPath, originalConfig, 0644)
require.NoError(t, os.WriteFile(configPath, originalConfig, 0o600))
}
}()

View File

@@ -11,20 +11,21 @@ import (
"testing"
"time"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
)
func init() {
@@ -39,7 +40,8 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 默认返回成功,由各测试 case 覆盖
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`))
_, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
}))
db := setupTestDB(t)
@@ -124,7 +126,6 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
require.Equal(t, 201, w.Code)
modelBody, _ := json.Marshal(map[string]string{
"provider_id": providerID,
"model_name": modelName,
})
@@ -143,9 +144,10 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
// 配置上游返回 Anthropic 格式响应
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求被转换为 Anthropic 格式
body, _ := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any
json.Unmarshal(body, &req)
require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/v1/messages", r.URL.Path)
assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
@@ -166,7 +168,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
require.NoError(t, json.NewEncoder(w).Encode(resp))
})
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
@@ -189,13 +191,16 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code)
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "chat.completion", resp["object"])
choices := resp["choices"].([]any)
choices, ok := resp["choices"].([]any)
require.True(t, ok)
require.Len(t, choices, 1)
choice := choices[0].(map[string]any)
msg := choice["message"].(map[string]any)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Contains(t, msg["content"], "Hello from Anthropic!")
}
@@ -203,9 +208,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
r, _, upstream := setupConversionTest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any
json.Unmarshal(body, &req)
require.NoError(t, json.Unmarshal(body, &req))
assert.Equal(t, "/chat/completions", r.URL.Path)
assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
@@ -229,7 +235,7 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
require.NoError(t, json.NewEncoder(w).Encode(resp))
})
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
@@ -252,12 +258,14 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
assert.Equal(t, 200, w.Code)
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "message", resp["type"])
content := resp["content"].([]any)
content, ok := resp["content"].([]any)
require.True(t, ok)
require.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok2 := content[0].(map[string]any)
require.True(t, ok2)
assert.Contains(t, block["text"], "Hello from OpenAI!")
}
@@ -269,21 +277,23 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/chat/completions", r.URL.Path)
body, _ := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any
json.Unmarshal(body, &req)
require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "gpt-4", req["model"])
w.Header().Set("Content-Type", "application/json")
// 上游返回上游模型名
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
_, err = w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
require.NoError(t, err)
})
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
reqBody := map[string]any{
"model": "openai_p/gpt-4", // 客户端发送统一 ID
"model": "openai_p/gpt-4", // 客户端发送统一 ID
"messages": []map[string]any{{"role": "user", "content": "test"}},
}
body, _ := json.Marshal(reqBody)
@@ -304,21 +314,23 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/messages", r.URL.Path)
body, _ := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req map[string]any
json.Unmarshal(body, &req)
require.NoError(t, json.Unmarshal(body, &req))
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
assert.Equal(t, "claude-3-opus", req["model"])
// 上游返回上游模型名
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
_, err = w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
require.NoError(t, err)
})
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
reqBody := map[string]any{
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
"max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}},
}
@@ -352,7 +364,8 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
@@ -393,7 +406,8 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
"data: [DONE]\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
@@ -447,11 +461,13 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err)
var anthropicResp map[string]any
json.Unmarshal(anthropicBody, &anthropicResp)
data := anthropicResp["data"].([]any)
require.NoError(t, json.Unmarshal(anthropicBody, &anthropicResp))
data, okd := anthropicResp["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2)
first := data[0].(map[string]any)
first, okf := data[0].(map[string]any)
require.True(t, okf)
assert.Equal(t, "gpt-4", first["id"])
assert.Equal(t, "model", first["type"])
@@ -466,11 +482,12 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
require.NoError(t, err)
var openaiResp map[string]any
json.Unmarshal(openaiBody, &err)
json.Unmarshal(openaiBody, &openaiResp)
oaiData := openaiResp["data"].([]any)
require.NoError(t, json.Unmarshal(openaiBody, &openaiResp))
oaiData, oki := openaiResp["data"].([]any)
require.True(t, oki)
assert.Len(t, oaiData, 1)
firstOai := oaiData[0].(map[string]any)
firstOai, okf2 := oaiData[0].(map[string]any)
require.True(t, okf2)
assert.Equal(t, "claude-3-opus", firstOai["id"])
}
@@ -537,7 +554,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
require.Equal(t, 201, w.Code)
var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "sk-test", created["api_key"])
// 获取时应包含 protocol
@@ -547,7 +564,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
assert.Equal(t, 200, w.Code)
var fetched map[string]any
json.Unmarshal(w.Body.Bytes(), &fetched)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &fetched))
assert.Equal(t, "anthropic", fetched["protocol"])
}
@@ -570,11 +587,13 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
require.Equal(t, 201, w.Code)
var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
assert.Equal(t, "openai", created["protocol"])
}
// Suppress unused imports
var _ = fmt.Sprintf
var _ = strings.Contains
var _ = time.Second
var (
_ = fmt.Sprintf
_ = strings.Contains
_ = time.Second
)

View File

@@ -12,19 +12,20 @@ import (
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
openaiConv "nex/backend/internal/conversion/openai"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
openaiConv "nex/backend/internal/conversion/openai"
)
func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
@@ -33,7 +34,8 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"error":"not mocked"}`))
_, err := w.Write([]byte(`{"error":"not mocked"}`))
require.NoError(t, err)
}))
db := setupTestDB(t)
@@ -115,11 +117,12 @@ func parseSSEEvents(body string) []map[string]string {
var currentEvent, currentData string
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "event: ") {
switch {
case strings.HasPrefix(line, "event: "):
currentEvent = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") {
case strings.HasPrefix(line, "data: "):
currentData = strings.TrimPrefix(line, "data: ")
} else if line == "" && (currentEvent != "" || currentData != "") {
case line == "" && (currentEvent != "" || currentData != ""):
events = append(events, map[string]string{
"event": currentEvent,
"data": currentData,
@@ -157,21 +160,21 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/chat/completions", req.URL.Path)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-001",
"object": "chat.completion",
"created": 1700000000,
"model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "你好我是AI助手。"},
"index": 0,
"message": map[string]any{"role": "assistant", "content": "你好我是AI助手。"},
"finish_reason": "stop",
"logprobs": nil,
"logprobs": nil,
}},
"usage": map[string]any{
"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25,
},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -210,21 +213,23 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
msgs := reqBody["messages"].([]any)
require.NoError(t, json.Unmarshal(body, &reqBody))
msgs, ok := reqBody["messages"].([]any)
require.True(t, ok)
assert.GreaterOrEqual(t, len(msgs), 3)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"},
"finish_reason": "stop", "logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -252,7 +257,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
@@ -272,7 +277,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -286,9 +291,9 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
"function": map[string]any{
"name": "get_weather", "description": "获取天气",
"parameters": map[string]any{
"type": "object",
"type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"},
"required": []string{"city"},
},
},
}},
@@ -319,22 +324,22 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
"index": 0,
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
"finish_reason": "length",
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
"max_tokens": 30,
})
w := httptest.NewRecorder()
@@ -353,11 +358,11 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "答案是61。"},
"index": 0,
"message": map[string]any{"role": "assistant", "content": "答案是61。"},
"finish_reason": "stop",
"logprobs": nil,
}},
@@ -365,12 +370,12 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
"prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83,
"completion_tokens_details": map[string]any{"reasoning_tokens": 20},
},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/o3",
"model": "openai_p/o3",
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
})
w := httptest.NewRecorder()
@@ -393,12 +398,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"role": "assistant",
"content": nil,
"refusal": "抱歉,我无法提供涉及危险活动的信息。",
},
@@ -406,12 +411,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "做坏事"}},
})
w := httptest.NewRecorder()
@@ -453,9 +458,9 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "你好"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -497,14 +502,14 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{
"type": "function",
"function": map[string]any{
"name": "get_weather", "description": "获取天气",
"parameters": map[string]any{
"type": "object",
"type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}},
},
},
@@ -546,9 +551,9 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "hi"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -569,14 +574,14 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_001", "type": "message", "role": "assistant",
"content": []map[string]any{
{"type": "text", "text": "你好我是Claude由Anthropic开发的AI助手。"},
},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -611,24 +616,25 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
assert.NotNil(t, reqBody["system"])
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_003", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"system": "你是编程助手",
"system": "你是编程助手",
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
})
w := httptest.NewRecorder()
@@ -643,7 +649,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_009", "type": "message", "role": "assistant",
"content": []map[string]any{{
"type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather",
@@ -651,7 +657,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
}},
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -661,9 +667,9 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
"tools": []map[string]any{{
"name": "get_weather", "description": "获取天气",
"input_schema": map[string]any{
"type": "object",
"type": "object",
"properties": map[string]any{"city": map[string]any{"type": "string"}},
"required": []string{"city"},
"required": []string{"city"},
},
}},
"tool_choice": map[string]any{"type": "auto"},
@@ -689,7 +695,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_018", "type": "message", "role": "assistant",
"content": []map[string]any{
{"type": "thinking", "thinking": "这是一个逻辑推理问题..."},
@@ -697,7 +703,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -724,12 +730,12 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_016", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}},
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -752,18 +758,18 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_017", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}},
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"stop_sequences": []string{"5"},
})
w := httptest.NewRecorder()
@@ -781,19 +787,20 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
metadata, _ := reqBody["metadata"].(map[string]any)
assert.Equal(t, "user_12345", metadata["user_id"])
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_026", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -814,21 +821,21 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_025", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "你好!"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{
"input_tokens": 25, "output_tokens": 5,
"cache_creation_input_tokens": 15, "cache_read_input_tokens": 0,
},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
"messages": []map[string]any{{"role": "user", "content": "你好"}},
})
w := httptest.NewRecorder()
@@ -864,7 +871,8 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush()
time.Sleep(10 * time.Millisecond)
}
@@ -874,7 +882,7 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "你好"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -922,7 +930,7 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
"messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -961,14 +969,14 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{
"id": "msg_cross_001", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "跨协议响应"}},
"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
})
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model",
"model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
})
w := httptest.NewRecorder()
@@ -1050,9 +1058,9 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model",
"model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1092,7 +1100,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1128,7 +1136,7 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/nonexistent",
"model": "openai_p/nonexistent",
"messages": []map[string]any{{"role": "user", "content": "test"}},
})
w := httptest.NewRecorder()
@@ -1183,11 +1191,11 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
"content": nil,
"tool_calls": []map[string]any{
{
"id": "call_ptc_1", "type": "function",
"id": "call_ptc_1", "type": "function",
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"北京"}`},
},
{
"id": "call_ptc_2", "type": "function",
"id": "call_ptc_2", "type": "function",
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"上海"}`},
},
},
@@ -1201,7 +1209,7 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
"tools": []map[string]any{{
"type": "function",
@@ -1242,10 +1250,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-stop", "object": "chat.completion", "created": 1700000060, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
"finish_reason": "stop",
"logprobs": nil,
"index": 0,
"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
"finish_reason": "stop",
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
})
@@ -1253,9 +1261,9 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
"stop": []string{"5"},
"stop": []string{"5"},
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
@@ -1291,7 +1299,7 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "危险内容"}},
})
w := httptest.NewRecorder()
@@ -1353,21 +1361,22 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
tc, _ := reqBody["tool_choice"].(map[string]any)
assert.Equal(t, "any", tc["type"])
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tca", "type": "message", "role": "assistant",
"content": []map[string]any{
{"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}},
},
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1397,20 +1406,21 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
sys, ok := reqBody["system"].([]any)
require.True(t, ok, "system should be an array")
require.GreaterOrEqual(t, len(sys), 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_asys", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1433,21 +1443,22 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3)
lastMsg := msgs[len(msgs)-1].(map[string]any)
assert.Equal(t, "user", lastMsg["role"])
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "msg_e2e_tr", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "北京当前晴天温度25°C。"}},
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
})
}))
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
@@ -1497,7 +1508,8 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush()
time.Sleep(10 * time.Millisecond)
}
@@ -1559,7 +1571,7 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model",
"model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{
"type": "function",
@@ -1634,14 +1646,14 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
json.NewEncoder(w).Encode(map[string]any{
"id": "msg_cross_stop", "type": "message", "role": "assistant",
"content": []map[string]any{{"type": "text", "text": "被截断的内容..."}},
"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
"usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
})
})
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model",
"model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "长文"}},
})
w := httptest.NewRecorder()
@@ -1659,9 +1671,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
r, upstream := setupE2ETest(t)
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
var reqBody map[string]any
json.Unmarshal(body, &reqBody)
require.NoError(t, json.Unmarshal(body, &reqBody))
msgs := reqBody["messages"].([]any)
require.GreaterOrEqual(t, len(msgs), 3)
toolMsg := msgs[2].(map[string]any)
@@ -1669,16 +1682,16 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"])
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o",
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{"role": "assistant", "content": "北京当前晴天温度25°C。"},
"index": 0,
"message": map[string]any{"role": "assistant", "content": "北京当前晴天温度25°C。"},
"finish_reason": "stop",
"logprobs": nil,
}},
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
})
}))
})
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
@@ -1722,7 +1735,8 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush()
time.Sleep(10 * time.Millisecond)
}
@@ -1730,7 +1744,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-model",
"model": "anthropic_p/claude-model",
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
"tools": []map[string]any{{
"type": "function",
@@ -1817,7 +1831,7 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
body, _ := json.Marshal(map[string]any{
"model": "openai_p/gpt-4o",
"model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "test"}},
})
w := httptest.NewRecorder()
@@ -1879,7 +1893,8 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
_, err := w.Write([]byte(e))
require.NoError(t, err)
flusher.Flush()
time.Sleep(10 * time.Millisecond)
}
@@ -1889,7 +1904,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
body, _ := json.Marshal(map[string]any{
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}},
"stream": true,
"stream": true,
})
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
@@ -1902,5 +1917,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
assert.Contains(t, respBody, "正常")
}
var _ = fmt.Sprintf
var _ = time.Now
var (
_ = fmt.Sprintf
_ = time.Now
)

View File

@@ -7,16 +7,17 @@ import (
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
)
func init() {
@@ -97,7 +98,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code)
var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
assert.NotEmpty(t, createdModel.ID)
// 3. 列出 Provider
@@ -106,7 +107,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
var providers []domain.Provider
json.Unmarshal(w.Body.Bytes(), &providers)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
assert.Len(t, providers, 1)
assert.Equal(t, "sk-test-key", providers[0].APIKey)
@@ -116,7 +117,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
var models []domain.Model
json.Unmarshal(w.Body.Bytes(), &models)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
assert.Len(t, models, 1)
assert.Equal(t, "gpt-4", models[0].ModelName)
@@ -163,7 +164,7 @@ func TestAnthropic_ModelCreation(t *testing.T) {
r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code)
var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
// 验证创建成功
w = httptest.NewRecorder()
@@ -194,9 +195,9 @@ func TestStats_RecordingAndQuery(t *testing.T) {
// 直接通过 repository 记录统计(模拟代理请求后的统计记录)
statsRepo := repository.NewStatsRepository(db)
statsRepo.Record("p1", "gpt-4")
statsRepo.Record("p1", "gpt-4")
statsRepo.Record("p1", "gpt-4")
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
// 查询统计
w = httptest.NewRecorder()
@@ -205,7 +206,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
assert.Equal(t, 200, w.Code)
var stats []domain.UsageStats
json.Unmarshal(w.Body.Bytes(), &stats)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
assert.Len(t, stats, 1)
assert.Equal(t, 3, stats[0].RequestCount)

View File

@@ -4,11 +4,11 @@ import (
"testing"
"time"
"nex/backend/internal/config"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。

View File

@@ -3,9 +3,10 @@ package tests
import (
"testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {

View File

@@ -10,9 +10,10 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,9 +10,10 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -11,9 +11,10 @@ package mocks
import (
context "context"
reflect "reflect"
conversion "nex/backend/internal/conversion"
provider "nex/backend/internal/provider"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,9 +10,10 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,9 +10,10 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,9 +10,10 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,10 +10,11 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -10,10 +10,11 @@
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
domain "nex/backend/internal/domain"
gomock "go.uber.org/mock/gomock"
)

View File

@@ -90,7 +90,7 @@ func TestConstraint_UniqueProviderModel(t *testing.T) {
}
err = db.Create(&model2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name) 的 model 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}
@@ -120,7 +120,7 @@ func TestConstraint_UniqueUsageStats(t *testing.T) {
}
err = db.Create(&stats2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name, date) 的 usage_stats 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}

5
lefthook.yml Normal file
View File

@@ -0,0 +1,5 @@
pre-commit:
commands:
backend-lint:
glob: "backend/**/*.go"
run: cd backend && go tool golangci-lint run --new-from-rev HEAD ./...

View File

@@ -169,15 +169,15 @@
- **WHEN** 应用启动
- **THEN** SHALL 按以下顺序加载配置:
1. 解析 CLI 参数(获取 --config 路径)
2. 初始化配置管理器
3. 设置默认值
4. 绑定 CLI 参数
5. 绑定环境变量
6. 读取配置文件
7. 反序列化到结构体
8. 验证配置
9. 打印配置摘要
1. 解析 CLI 参数(获取 --config 路径)
2. 初始化配置管理器
3. 设置默认值
4. 绑定 CLI 参数
5. 绑定环境变量
6. 读取配置文件(不存在时自动创建)
7. 反序列化到结构体
8. 验证配置
9. 打印配置摘要
#### Scenario: 加载失败处理

View File

@@ -31,16 +31,27 @@
- **THEN** SHALL 测试请求转换、响应转换、流式转换
- **THEN** SHALL 验证转换的准确性和完整性
#### Scenario: config 加载管道集成测试
#### Scenario: LoadConfigFromPath 默认值验证
- **WHEN** 运行 config 加载管道的集成测试
- **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值
- **THEN** SHALL 验证环境变量(`NEX_` 前缀)覆盖默认值
- **THEN** SHALL 验证 YAML 配置文件正确读取
- **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值
- **THEN** SHALL 验证首次启动自动创建配置文件
- **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致
#### Scenario: 环境变量覆盖验证
- **WHEN** 设置 `NEX_SERVER_PORT=9000``NEX_LOG_LEVEL=debug`
- **THEN** SHALL 成功加载
- **THEN** 配置值 SHALL 反映环境变量覆盖
#### Scenario: 自动创建配置文件验证
- **WHEN** 调用 `LoadConfigFromPath` 并指向不存在的文件路径
- **THEN** SHALL 成功加载(不返回 `missing configuration for 'configPath'` 错误)
- **THEN** SHALL 返回默认配置对象
#### Scenario: handler 错误分支测试
- **WHEN** 运行 handler 层的单元测试