chore: 合并 dev-code-backend-format 到 master
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -401,6 +401,7 @@ cython_debug/
|
||||
# Custom
|
||||
.claude
|
||||
.opencode
|
||||
.codex
|
||||
openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
|
||||
12
Makefile
12
Makefile
@@ -1,11 +1,11 @@
|
||||
.PHONY: all dev build test lint clean \
|
||||
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \
|
||||
backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
|
||||
backend-lint backend-clean backend-deps backend-generate \
|
||||
backend-db-up backend-db-down backend-db-status backend-db-create \
|
||||
test-mysql-up test-mysql-down test-mysql test-mysql-quick \
|
||||
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
|
||||
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
|
||||
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
|
||||
desktop-dev desktop-test desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
|
||||
desktop-prepare-frontend desktop-prepare-embedfs
|
||||
|
||||
# ============================================
|
||||
@@ -19,7 +19,7 @@ dev:
|
||||
build: backend-build frontend-build
|
||||
@echo "✅ Build complete"
|
||||
|
||||
test: backend-test frontend-test
|
||||
test: backend-test desktop-test frontend-test
|
||||
@echo "✅ All tests passed"
|
||||
|
||||
lint: backend-lint frontend-lint
|
||||
@@ -41,6 +41,9 @@ backend-dev:
|
||||
cd backend && go run ./cmd/server
|
||||
|
||||
backend-test:
|
||||
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||
|
||||
backend-test-all:
|
||||
cd backend && go test ./... -v
|
||||
|
||||
backend-test-unit:
|
||||
@@ -179,6 +182,9 @@ desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
|
||||
@echo "🖥️ Starting desktop app in dev mode..."
|
||||
cd backend && go run ./cmd/desktop
|
||||
|
||||
desktop-test:
|
||||
cd backend && go test ./cmd/desktop/... -v
|
||||
|
||||
desktop-package-mac:
|
||||
./scripts/build/package-macos.sh
|
||||
|
||||
|
||||
@@ -294,6 +294,9 @@ make frontend-test-coverage # 前端覆盖率
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
# 首次克隆后安装 Git hooks
|
||||
lefthook install
|
||||
|
||||
# 顶层便捷命令
|
||||
make dev # 启动开发环境(并行启动后端和前端)
|
||||
make build # 构建所有产物
|
||||
|
||||
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- forbidigo
|
||||
- errorlint
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- bodyclose
|
||||
- noctx
|
||||
- nilerr
|
||||
- goimports
|
||||
- gocyclo
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-blank: true
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- p: '^fmt\.Print.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^fmt\.Fprint.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||
msg: 使用 zap logger,不要使用标准库 log
|
||||
- p: '^zap\.L$'
|
||||
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||
- p: '^zap\.S$'
|
||||
msg: 不使用 Sugar logger
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: error-strings
|
||||
- name: error-return
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: unexported-return
|
||||
goimports:
|
||||
local-prefixes: nex/backend
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- tests/mocks
|
||||
exclude-generated: true
|
||||
exclude-rules:
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- errcheck
|
||||
source: '(^\s*_\s*=|,\s*_)'
|
||||
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||
linters:
|
||||
- errcheck
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- revive
|
||||
text: '^exported:'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gosec
|
||||
text: 'G(101|401|501)'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gocyclo
|
||||
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||
- linters:
|
||||
- revive
|
||||
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||
- path: 'internal/conversion/.*\.go'
|
||||
linters:
|
||||
- gocyclo
|
||||
- gocritic
|
||||
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||
linters:
|
||||
- gocyclo
|
||||
@@ -609,6 +609,7 @@ err := v.Validate(myStruct)
|
||||
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||
|
||||
@@ -6,19 +6,25 @@ import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func showAbout() {
|
||||
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
|
||||
escapeAppleScript(message))
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示关于对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
|
||||
@@ -4,7 +4,6 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
@@ -63,7 +62,7 @@ func showError(title, message string) {
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("%s: %s", title, message)).Run()
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message)
|
||||
dialogLogger().Error("无法显示错误对话框")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +82,6 @@ func showAbout() {
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message)
|
||||
dialogLogger().Info("关于 Nex Gateway")
|
||||
}
|
||||
}
|
||||
|
||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func dialogLogger() *zap.Logger {
|
||||
if zapLogger != nil {
|
||||
return zapLogger
|
||||
}
|
||||
|
||||
return pkgLogger.NewMinimal()
|
||||
}
|
||||
@@ -13,10 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
"nex/embedfs"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -28,9 +25,13 @@ import (
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
|
||||
"nex/embedfs"
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -51,12 +52,16 @@ func main() {
|
||||
showError("Nex Gateway", "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer singleLock.Unlock()
|
||||
defer func() {
|
||||
if err := singleLock.Unlock(); err != nil {
|
||||
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
minimalLogger.Error("端口不可用", zap.Error(err))
|
||||
showError("Nex Gateway", err.Error())
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
@@ -75,7 +80,11 @@ func main() {
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
@@ -144,14 +153,14 @@ func main() {
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
|
||||
zapLogger.Warn("无法打开浏览器", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -193,7 +202,7 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
@@ -266,7 +275,7 @@ func setupSystray(port int) {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
|
||||
zapLogger.Error("无法加载托盘图标", zap.Error(err))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTitle("Nex Gateway")
|
||||
@@ -287,7 +296,9 @@ func setupSystray(port int) {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
openBrowser(fmt.Sprintf("http://localhost:%d", port))
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mAbout.ClickedCh:
|
||||
showAbout()
|
||||
case <-mQuit.ClickedCh:
|
||||
@@ -308,7 +319,9 @@ func doShutdown() {
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
server.Shutdown(ctx)
|
||||
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
@@ -346,8 +359,8 @@ func (s *SingletonLock) Lock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Unlock() {
|
||||
s.flock.Unlock()
|
||||
func (s *SingletonLock) Unlock() error {
|
||||
return s.flock.Unlock()
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestCheckPortAvailable(t *testing.T) {
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827")
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19827")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
@@ -47,13 +47,19 @@ func TestCheckPortOccupied(t *testing.T) {
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", ":19828")
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{}
|
||||
go server.Serve(listener)
|
||||
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
t.Errorf("serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
||||
if err := lock.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||
}
|
||||
defer lock.Unlock()
|
||||
defer func() {
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
defer lock1.Unlock()
|
||||
defer func() {
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
err := lock2.Lock()
|
||||
if err == nil {
|
||||
lock2.Unlock()
|
||||
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||
t.Fatalf("解锁失败: %v", unlockErr)
|
||||
}
|
||||
t.Fatal("重复加锁应失败,但返回 nil")
|
||||
}
|
||||
}
|
||||
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
lock1.Unlock()
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
if err := lock2.Lock(); err != nil {
|
||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||
}
|
||||
lock2.Unlock()
|
||||
if err := lock2.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||
lock.Unlock()
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
|
||||
@@ -44,7 +44,11 @@ func main() {
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -58,7 +59,10 @@ type LogConfig struct {
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
|
||||
|
||||
// setupDefaults 设置默认配置值
|
||||
func setupDefaults(v *viper.Viper) {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
v.SetDefault("server.port", 9826)
|
||||
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
|
||||
// 绑定所有 flag 到 viper
|
||||
// 注意:必须在设置默认值之后绑定
|
||||
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
|
||||
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
|
||||
v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
|
||||
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
|
||||
v.BindPFlag("database.host", flagSet.Lookup("database-host"))
|
||||
v.BindPFlag("database.port", flagSet.Lookup("database-port"))
|
||||
v.BindPFlag("database.user", flagSet.Lookup("database-user"))
|
||||
v.BindPFlag("database.password", flagSet.Lookup("database-password"))
|
||||
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
|
||||
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
|
||||
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
|
||||
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
|
||||
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
|
||||
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
|
||||
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
|
||||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||
}
|
||||
|
||||
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||
if err := v.BindPFlag(key, flag); err != nil {
|
||||
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||
}
|
||||
}
|
||||
|
||||
// setupEnv 绑定环境变量
|
||||
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
// 配置文件不存在,创建默认配置文件
|
||||
if err := v.SafeWriteConfig(); err != nil {
|
||||
// 忽略写入错误(可能目录已存在等)
|
||||
writeErr := v.SafeWriteConfigAs(configPath)
|
||||
if writeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
|
||||
if errors.As(writeErr, &alreadyExistsErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
setupFlags(v, flagSet)
|
||||
|
||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||||
flagSet.Parse(os.Args[1:])
|
||||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||
}
|
||||
|
||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||||
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0600)
|
||||
return os.WriteFile(configPath, data, 0o600)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name))
|
||||
@@ -47,4 +47,3 @@ func (Model) TableName() string {
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
|
||||
@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -52,10 +53,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
@@ -102,9 +103,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.True(t, errors.As(err, &convErr))
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
decoded, err := decodeMessage(msg)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
|
||||
}
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
|
||||
}
|
||||
|
||||
// decodeMessage 解码 Anthropic 消息
|
||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var toolResults []canonical.ContentBlock
|
||||
var others []canonical.ContentBlock
|
||||
for _, b := range blocks {
|
||||
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
if len(result) == 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||
}
|
||||
return result
|
||||
return result, nil
|
||||
|
||||
case "assistant":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
block, err := decodeSingleContentBlock(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
return blocks, nil
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
name, ok := m["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
input, err := json.Marshal(m["input"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
toolUseID, ok := m["tool_use_id"].(string)
|
||||
if !ok {
|
||||
toolUseID = ""
|
||||
}
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
encoded, err := json.Marshal(cv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = encoded
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
}, nil
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
thinking, ok := m["thinking"].(string)
|
||||
if !ok {
|
||||
thinking = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
name, ok := v["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
|
||||
result = append(result, m)
|
||||
case "tool_result":
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"stop_reason": sr,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
|
||||
@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
|
||||
assert.Equal(t, true, result["stream"])
|
||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
}
|
||||
|
||||
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
|
||||
// tool 消息应被合并到相邻 user 消息
|
||||
foundToolResult := false
|
||||
for _, m := range msgs {
|
||||
msgMap := m.(map[string]any)
|
||||
msgMap, ok := m.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if msgMap["role"] == "user" {
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if ok {
|
||||
for _, c := range content {
|
||||
block := c.(map[string]any)
|
||||
block, ok := c.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if block["type"] == "tool_result" {
|
||||
foundToolResult = true
|
||||
}
|
||||
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", firstMsg["role"])
|
||||
}
|
||||
|
||||
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "assistant", result["role"])
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", block["type"])
|
||||
assert.Equal(t, "你好", block["text"])
|
||||
}
|
||||
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
data := result["data"].([]any)
|
||||
data, ok := result["data"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 1)
|
||||
|
||||
model := data[0].(map[string]any)
|
||||
model, ok := data[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-3-opus", model["id"])
|
||||
// created 应为 RFC3339 格式
|
||||
createdAt, ok := model["created_at"].(string)
|
||||
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
userMsg := msgs[0].(map[string]any)
|
||||
userMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", userMsg["role"])
|
||||
content := userMsg["content"].([]any)
|
||||
content, ok := userMsg["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, ok := result["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasReasoning := usage["reasoning_tokens"]
|
||||
assert.False(t, hasReasoning)
|
||||
}
|
||||
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", block["type"])
|
||||
assert.Equal(t, "tool_1", block["id"])
|
||||
assert.Equal(t, "search", block["name"])
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
|
||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
data := rawChunk
|
||||
if len(d.utf8Remainder) > 0 {
|
||||
data = append(d.utf8Remainder, rawChunk...)
|
||||
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
|
||||
d.utf8Remainder = nil
|
||||
}
|
||||
|
||||
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event: "):
|
||||
eventType = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
case strings.HasPrefix(line, "data: "):
|
||||
eventData = strings.TrimPrefix(line, "data: ")
|
||||
if eventType != "" && eventData != "" {
|
||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
}
|
||||
eventType = ""
|
||||
eventData = ""
|
||||
} else if line == "" {
|
||||
// SSE 事件分隔符
|
||||
case line == "":
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
|
||||
@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
checkValue string
|
||||
}{
|
||||
{
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
checkField: "text",
|
||||
checkValue: "你好",
|
||||
},
|
||||
{
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
checkField: "partial_json",
|
||||
checkValue: "{\"key\":",
|
||||
},
|
||||
{
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
checkField: "thinking",
|
||||
checkValue: "思考中",
|
||||
},
|
||||
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": tt.deltaData,
|
||||
"delta": tt.deltaData,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
|
||||
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
|
||||
"type": "content_block_start",
|
||||
"index": 3,
|
||||
"content_block": map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "search_1",
|
||||
},
|
||||
}
|
||||
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
},
|
||||
}
|
||||
deltaPayload1 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 25},
|
||||
}
|
||||
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
assert.Equal(t, 25, events[0].Usage.OutputTokens)
|
||||
|
||||
deltaPayload2 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 30},
|
||||
}
|
||||
|
||||
@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
delta, okd := payload["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
}
|
||||
|
||||
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
u, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
|
||||
@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||
blocks := decodeContentBlocks(nil)
|
||||
blocks, err := decodeContentBlocks(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||
blocks := decodeContentBlocks("hello")
|
||||
blocks, err := decodeContentBlocks("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "hello", blocks[0].Text)
|
||||
}
|
||||
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeToolChoice(tt.choice)
|
||||
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
|
||||
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
|
||||
r, ok := result.(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.want["type"], r["type"])
|
||||
assert.Equal(t, tt.want["name"], r["name"])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
tools := result["tools"].([]any)
|
||||
tools, okt := result["tools"].([]any)
|
||||
require.True(t, okt)
|
||||
assert.Len(t, tools, 1)
|
||||
tool := tools[0].(map[string]any)
|
||||
tool, okt2 := tools[0].(map[string]any)
|
||||
require.True(t, okt2)
|
||||
assert.Equal(t, "search", tool["name"])
|
||||
assert.Equal(t, "Search things", tool["description"])
|
||||
tc := result["tool_choice"].(map[string]any)
|
||||
tc, oktc := result["tool_choice"].(map[string]any)
|
||||
require.True(t, oktc)
|
||||
assert.Equal(t, "auto", tc["type"])
|
||||
}
|
||||
|
||||
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
CacheCreationTokens: &cacheCreation,
|
||||
},
|
||||
}
|
||||
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||
|
||||
@@ -6,22 +6,22 @@ import (
|
||||
|
||||
// MessagesRequest Anthropic Messages 请求
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
// RequestMetadata 请求元数据
|
||||
@@ -122,8 +122,8 @@ type ContentBlock struct {
|
||||
|
||||
// ResponseUsage 响应用量
|
||||
type ResponseUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
|
||||
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
|
||||
|
||||
// EmbeddingData 嵌入数据项
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
}
|
||||
|
||||
// EmbeddingUsage 嵌入用量
|
||||
|
||||
@@ -18,17 +18,17 @@ const (
|
||||
type DeltaType string
|
||||
|
||||
const (
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
)
|
||||
|
||||
// StreamDelta 流式增量联合体
|
||||
type StreamDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// StreamContentBlock 流式内容块联合体
|
||||
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
|
||||
Message *StreamMessage `json:"message,omitempty"`
|
||||
|
||||
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
|
||||
Index *int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
|
||||
// MessageDeltaEvent
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage *CanonicalUsage `json:"usage,omitempty"`
|
||||
|
||||
// ErrorEvent
|
||||
|
||||
@@ -40,8 +40,8 @@ type ContentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// ToolUseBlock
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// ToolResultBlock
|
||||
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
|
||||
|
||||
// OutputFormat 输出格式联合体
|
||||
type OutputFormat struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalRequest 规范请求
|
||||
type CanonicalRequest struct {
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Messages []CanonicalMessage `json:"messages"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalUsage 规范用量
|
||||
type CanonicalUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalResponse 规范响应
|
||||
type CanonicalResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetSystemString 获取系统消息字符串
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
func TestGetSystemString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
}{
|
||||
{"string", "hello", "hello"},
|
||||
{"nil", nil, ""},
|
||||
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
|
||||
func TestCanonicalResponse_RoundTrip(t *testing.T) {
|
||||
sr := StopReasonEndTurn
|
||||
resp := &CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
StopReason: &sr,
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
|
||||
@@ -114,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
@@ -122,7 +122,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + providerUrl,
|
||||
URL: provider.BaseURL + providerURL,
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
@@ -134,24 +134,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if err != nil {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if rewriteErr != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.Error(err),
|
||||
zap.Error(rewriteErr),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
return &spec, nil
|
||||
} else {
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
@@ -182,11 +179,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
@@ -201,9 +197,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
}
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
Timestamp: time.Now(),
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
return NewCanonicalStreamConverterWithMiddleware(
|
||||
@@ -306,7 +302,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
|
||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||
@@ -320,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
|
||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -334,7 +330,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
@@ -343,7 +339,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
@@ -355,21 +351,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||
if decodeErr == nil {
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// DetectInterfaceType 检测接口类型
|
||||
@@ -391,8 +388,12 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(fallback)
|
||||
return body, 500, nil
|
||||
body, marshalErr := json.Marshal(fallback)
|
||||
if marshalErr == nil {
|
||||
return body, 500, nil
|
||||
}
|
||||
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
|
||||
@@ -38,8 +38,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
|
||||
|
||||
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
|
||||
@@ -190,14 +190,16 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
|
||||
// noopStreamDecoder 空流式解码器
|
||||
type noopStreamDecoder struct{}
|
||||
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
|
||||
// noopStreamEncoder 空流式编码器
|
||||
type noopStreamEncoder struct{}
|
||||
|
||||
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
|
||||
// ============ 测试用例 ============
|
||||
|
||||
@@ -615,6 +617,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
if d.flushFn != nil {
|
||||
return d.flushFn()
|
||||
@@ -634,6 +637,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||
if e.flushFn != nil {
|
||||
return e.flushFn()
|
||||
|
||||
@@ -6,17 +6,17 @@ import "fmt"
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
)
|
||||
|
||||
// ConversionError 协议转换错误
|
||||
|
||||
@@ -4,10 +4,10 @@ package conversion
|
||||
type InterfaceType string
|
||||
|
||||
const (
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
|
||||
)
|
||||
|
||||
@@ -138,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Code: string(err.Code),
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -248,7 +251,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
@@ -282,12 +289,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
|
||||
@@ -48,10 +48,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
|
||||
@@ -92,9 +92,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
|
||||
@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
|
||||
// contentPart 内容部分
|
||||
type contentPart struct {
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
}
|
||||
|
||||
// decodeContentParts 解码内容部分
|
||||
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
|
||||
var result []contentPart
|
||||
for _, item := range parts {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "text", Text: text})
|
||||
case "refusal":
|
||||
refusal, _ := m["refusal"].(string)
|
||||
refusal, ok := m["refusal"].(string)
|
||||
if !ok {
|
||||
refusal = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||
}
|
||||
}
|
||||
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "function":
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
name, _ := fn["name"].(string)
|
||||
name, ok := fn["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "custom":
|
||||
if custom, ok := v["custom"].(map[string]any); ok {
|
||||
name, _ := custom["name"].(string)
|
||||
name, ok := custom["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "allowed_tools":
|
||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||
mode, _ := at["mode"].(string)
|
||||
mode, ok := at["mode"].(string)
|
||||
if !ok {
|
||||
mode = ""
|
||||
}
|
||||
if mode == "required" {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
|
||||
case map[string]any:
|
||||
if name, ok := v["name"].(string); ok {
|
||||
req.ToolChoice = map[string]any{
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": name},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 2)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "system", firstMsg["role"])
|
||||
assert.Equal(t, "你是助手", firstMsg["content"])
|
||||
}
|
||||
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assistantMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assistantMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, toolCalls, 1)
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
tc, ok := toolCalls[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_1", tc["id"])
|
||||
}
|
||||
|
||||
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
|
||||
func TestEncodeResponse_Basic(t *testing.T) {
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "resp-1", result["id"])
|
||||
assert.Equal(t, "chat.completion", result["object"])
|
||||
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
choices, ok := result["choices"].([]any)
|
||||
require.True(t, ok)
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
msg, ok := choice["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "你好", msg["content"])
|
||||
assert.Equal(t, "stop", choice["finish_reason"])
|
||||
}
|
||||
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
sr := canonical.StopReasonToolUse
|
||||
input := json.RawMessage(`{"q":"test"}`)
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okc := result["choices"].([]any)
|
||||
require.True(t, okc)
|
||||
msgMap, okm := choices[0].(map[string]any)
|
||||
require.True(t, okm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
tcs, ok := msg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, tcs, 1)
|
||||
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "list", result["object"])
|
||||
data := result["data"].([]any)
|
||||
data, okd := result["data"].([]any)
|
||||
require.True(t, okd)
|
||||
assert.Len(t, data, 2)
|
||||
}
|
||||
|
||||
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
assert.Equal(t, "回答", msg["content"])
|
||||
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
|
||||
|
||||
// 先发送一个文本 chunk
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx0 := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx1 := 1
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
// StreamEncoder OpenAI 流式编码器
|
||||
type StreamEncoder struct {
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
|
||||
@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
choices, okch := payload["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
delta, okd := msgMap["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
|
||||
@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "rerank-1", result["model"])
|
||||
results := result["results"].([]any)
|
||||
results, okr := result["results"].([]any)
|
||||
require.True(t, okr)
|
||||
assert.Len(t, results, 1)
|
||||
}
|
||||
|
||||
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
reasoning := 20
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
choice, okc := choices[0].(map[string]any)
|
||||
require.True(t, okc)
|
||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,42 +4,42 @@ import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completion 请求
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// 已废弃字段
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// Message OpenAI 消息
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
// 已废弃
|
||||
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
|
||||
@@ -88,8 +88,8 @@ type FunctionDef struct {
|
||||
|
||||
// ResponseFormat OpenAI 响应格式
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// JSONSchemaDef JSON Schema 定义
|
||||
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
|
||||
|
||||
// Choice OpenAI 选择项
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
@@ -127,10 +127,10 @@ type Choice struct {
|
||||
|
||||
// Usage OpenAI 用量
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
|
||||
db, err := initDB(cfg, moduleLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
|
||||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
default:
|
||||
dbDir := filepath.Dir(cfg.Path)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
if zapLogger != nil {
|
||||
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||
zap.String("dir", migrationsSubDir))
|
||||
}
|
||||
|
||||
goose.SetDialect(gooseDialect)
|
||||
if err := goose.SetDialect(gooseDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
func TestInit_SQLite(t *testing.T) {
|
||||
|
||||
@@ -13,4 +13,3 @@ type Provider struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -9,23 +9,22 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -20,7 +20,6 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
if id, ok := requestID.(string); ok {
|
||||
requestIDStr = id
|
||||
}
|
||||
|
||||
logger.Info("请求开始",
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == appErrors.ErrDuplicateModel {
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrInvalidProviderID {
|
||||
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrInvalidProviderID.Message,
|
||||
"code": appErrors.ErrInvalidProviderID.Code,
|
||||
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
|
||||
@@ -3,30 +3,32 @@ package handler
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ProxyHandler 统一代理处理器
|
||||
type ProxyHandler struct {
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
providerService service.ProviderService
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
@@ -138,7 +140,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
targetProvider := conversion.NewTargetProvider(
|
||||
routeResult.Provider.BaseURL,
|
||||
routeResult.Provider.APIKey,
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
@@ -159,7 +161,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -167,7 +169,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("发送请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -175,7 +177,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换响应失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -191,7 +193,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -226,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if _, err := writer.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
@@ -272,7 +290,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
// 从数据库查询所有启用的模型
|
||||
models, err := h.providerService.ListEnabledModels()
|
||||
if err != nil {
|
||||
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
||||
return
|
||||
}
|
||||
@@ -294,7 +312,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelsResponse(modelList)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
return
|
||||
}
|
||||
@@ -342,8 +360,13 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
|
||||
if encodeErr != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
|
||||
return
|
||||
}
|
||||
c.Data(statusCode, "application/json", body)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,27 +8,26 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
|
||||
t.Helper()
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
@@ -844,7 +843,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 2)
|
||||
|
||||
first := data[0].(map[string]interface{})
|
||||
first, ok2 := data[0].(map[string]interface{})
|
||||
require.True(t, ok2)
|
||||
assert.Equal(t, "openai/gpt-4", first["id"])
|
||||
}
|
||||
|
||||
@@ -918,7 +918,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
var req map[string]interface{}
|
||||
json.Unmarshal(spec.Body, &req)
|
||||
require.NoError(t, json.Unmarshal(spec.Body, &req))
|
||||
assert.Equal(t, "gpt-4", req["model"])
|
||||
|
||||
return &conversion.HTTPResponseSpec{
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
|
||||
@@ -51,6 +51,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
//
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
@@ -141,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
errBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d,读取错误响应失败: %w", resp.StatusCode, readErr)
|
||||
}
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
@@ -184,7 +188,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if isNetworkError(err) {
|
||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
||||
c.logger.Error("流网络错误", zap.Error(err))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||
} else {
|
||||
c.logger.Error("流读取错误", zap.Error(err))
|
||||
|
||||
@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -140,12 +142,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
@@ -165,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
@@ -215,7 +221,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":"ok"}`))
|
||||
_, err := w.Write([]byte(`{"result":"ok"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -238,10 +245,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}))
|
||||
@@ -261,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
var dataCount int
|
||||
var doneCount int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneCount++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataCount++
|
||||
}
|
||||
}
|
||||
@@ -279,10 +289,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
@@ -364,13 +376,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if hijacker, ok := w.(http.Hijacker); ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -3,10 +3,11 @@ package repository
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ package repository
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
|
||||
@@ -3,11 +3,11 @@ package repository
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type statsRepository struct {
|
||||
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
|
||||
}
|
||||
|
||||
func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
now := time.Now()
|
||||
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
|
||||
stats := config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type modelService struct {
|
||||
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
|
||||
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
|
||||
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil // 未找到,不重复
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // 未找到,不重复
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if excludeID != "" && existing.ID == excludeID {
|
||||
return nil // 排除自身
|
||||
|
||||
@@ -3,10 +3,10 @@ package service
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -34,7 +35,9 @@ func NewRoutingCache(
|
||||
|
||||
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := c.providerRepo.GetByID(id)
|
||||
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
}
|
||||
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.providers.Store(id, provider)
|
||||
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
|
||||
key := providerID + "/" + modelName
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
|
||||
}
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.models.Store(key, model)
|
||||
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
|
||||
prefix := providerID + "/"
|
||||
count := 0
|
||||
c.models.Range(func(key, value interface{}) bool {
|
||||
if strings.HasPrefix(key.(string), prefix) {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(keyStr, prefix) {
|
||||
c.models.Delete(key)
|
||||
count++
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockModelRepo struct {
|
||||
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
||||
|
||||
var openaiCount, anthropicCount int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
if key.(string) == "anthropic/claude" {
|
||||
keyStr, ok := key.(string)
|
||||
if ok && keyStr == "anthropic/claude" {
|
||||
anthropicCount++
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
|
||||
@@ -3,12 +3,12 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
func TestProviderService_Update(t *testing.T) {
|
||||
@@ -133,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
|
||||
totalCount := 0
|
||||
for _, r := range result {
|
||||
totalCount += r["request_count"].(int)
|
||||
count, ok := r["request_count"].(int)
|
||||
require.True(t, ok)
|
||||
totalCount += count
|
||||
}
|
||||
assert.Equal(t, 15, totalCount)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,8 +16,6 @@ import (
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -134,7 +135,7 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
@@ -148,7 +149,7 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
@@ -160,7 +161,7 @@ func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
@@ -176,7 +177,7 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
@@ -202,7 +203,7 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
@@ -215,7 +216,7 @@ func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
@@ -241,7 +242,7 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
@@ -259,7 +260,7 @@ func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop()); svc := NewStatsService(statsRepo, buffer)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
@@ -448,7 +451,7 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
||||
@@ -474,7 +477,7 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
@@ -6,9 +6,10 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -67,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
|
||||
|
||||
var counter *int64
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter = v.(*int64)
|
||||
if existing, ok := v.(*int64); ok {
|
||||
counter = existing
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
val := int64(0)
|
||||
counter = &val
|
||||
actual, loaded := b.counters.LoadOrStore(key, counter)
|
||||
if loaded {
|
||||
counter = actual.(*int64)
|
||||
existing, ok := actual.(*int64)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter = existing
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,13 +126,20 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
var entries []statEntry
|
||||
b.counters.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
parts := strings.Split(keyStr, "/")
|
||||
if len(parts) != 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
count := atomic.SwapInt64(counter, 0)
|
||||
|
||||
if count > 0 {
|
||||
@@ -143,8 +159,17 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
success := 0
|
||||
for _, entry := range entries {
|
||||
date, _ := time.Parse("2006-01-02", entry.date)
|
||||
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
date, err := time.Parse("2006-01-02", entry.date)
|
||||
if err != nil {
|
||||
b.logger.Error("解析统计日期失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.String("date", entry.date),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
if err != nil {
|
||||
b.logger.Error("批量更新统计失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
@@ -154,8 +179,10 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter := v.(*int64)
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
counter, ok := v.(*int64)
|
||||
if ok {
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
success++
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockStatsRepo struct {
|
||||
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count += atomic.LoadInt64(counter)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count += atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(3), count)
|
||||
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(100), count)
|
||||
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
|
||||
var beforeCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), beforeCount)
|
||||
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
|
||||
var afterCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(0), afterCount)
|
||||
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), count)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var appErr *AppError
|
||||
if ok := is(err, &appErr); ok {
|
||||
return appErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func is(err error, target interface{}) bool {
|
||||
// 简单的类型断言
|
||||
if e, ok := err.(*AppError); ok {
|
||||
// 直接赋值
|
||||
switch t := target.(type) {
|
||||
case **AppError:
|
||||
*t = e
|
||||
return true
|
||||
}
|
||||
var appErr *AppError
|
||||
if !stderrors.As(err, &appErr) {
|
||||
return nil, false
|
||||
}
|
||||
return false
|
||||
|
||||
return appErr, true
|
||||
}
|
||||
|
||||
@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
|
||||
|
||||
func TestAsAppError(t *testing.T) {
|
||||
t.Run("nil输入", func(t *testing.T) {
|
||||
_, ok := AsAppError(nil)
|
||||
appErr, ok := AsAppError(nil)
|
||||
assert.Nil(t, appErr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("非AppError类型", func(t *testing.T) {
|
||||
_, ok := AsAppError(errors.New("普通错误"))
|
||||
appErr, ok := AsAppError(errors.New("普通错误"))
|
||||
assert.Nil(t, appErr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestNew_StdoutOnly(t *testing.T) {
|
||||
|
||||
func TestNew_WithFileOutput(t *testing.T) {
|
||||
dir := filepath.Join(os.TempDir(), "nex-logger-test")
|
||||
os.MkdirAll(dir, 0755)
|
||||
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
logger, err := New(Config{
|
||||
@@ -81,7 +81,7 @@ func TestParseLevel(t *testing.T) {
|
||||
{"info", true},
|
||||
{"warn", true},
|
||||
{"error", true},
|
||||
{"", true}, // 默认为 info
|
||||
{"", true}, // 默认为 info
|
||||
{"invalid", true}, // 默认为 info
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -22,9 +22,9 @@ func newRotateWriter(cfg Config) *lumberjack.Logger {
|
||||
|
||||
return &lumberjack.Logger{
|
||||
Filename: logFilePath(cfg.Path),
|
||||
MaxSize: maxSize, // MB
|
||||
MaxSize: maxSize, // MB
|
||||
MaxBackups: maxBackups,
|
||||
MaxAge: maxAge, // days
|
||||
MaxAge: maxAge, // days
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
func TestLoadConfig_DefaultValues(t *testing.T) {
|
||||
@@ -72,7 +72,7 @@ log:
|
||||
max_age: 7
|
||||
compress: false
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
@@ -103,7 +103,7 @@ server:
|
||||
log:
|
||||
level: warn
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Setenv("NEX_SERVER_PORT", "9000")
|
||||
@@ -147,7 +147,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
}
|
||||
defer func() {
|
||||
if originalConfig != nil {
|
||||
_ = os.WriteFile(configPath, originalConfig, 0644)
|
||||
require.NoError(t, os.WriteFile(configPath, originalConfig, 0o600))
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -11,20 +11,21 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
openaiConv "nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -39,7 +40,8 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 默认返回成功,由各测试 case 覆盖
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"error":"not mocked"}`))
|
||||
_, err := w.Write([]byte(`{"error":"not mocked"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
db := setupTestDB(t)
|
||||
@@ -124,7 +126,6 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
modelBody, _ := json.Marshal(map[string]string{
|
||||
|
||||
"provider_id": providerID,
|
||||
"model_name": modelName,
|
||||
})
|
||||
@@ -143,9 +144,10 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
||||
// 配置上游返回 Anthropic 格式响应
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 验证请求被转换为 Anthropic 格式
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
|
||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||
assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
|
||||
@@ -166,7 +168,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
require.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
|
||||
@@ -189,13 +191,16 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "chat.completion", resp["object"])
|
||||
|
||||
choices := resp["choices"].([]any)
|
||||
choices, ok := resp["choices"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, choices, 1)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
msg, ok := choice["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, msg["content"], "Hello from Anthropic!")
|
||||
}
|
||||
|
||||
@@ -203,9 +208,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
|
||||
assert.Equal(t, "/chat/completions", r.URL.Path)
|
||||
assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
|
||||
@@ -229,7 +235,7 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
require.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||
@@ -252,12 +258,14 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "message", resp["type"])
|
||||
|
||||
content := resp["content"].([]any)
|
||||
content, ok := resp["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok2 := content[0].(map[string]any)
|
||||
require.True(t, ok2)
|
||||
assert.Contains(t, block["text"], "Hello from OpenAI!")
|
||||
}
|
||||
|
||||
@@ -269,21 +277,23 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/chat/completions", r.URL.Path)
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
|
||||
assert.Equal(t, "gpt-4", req["model"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// 上游返回上游模型名
|
||||
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
|
||||
_, err = w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "openai_p/gpt-4", // 客户端发送统一 ID
|
||||
"model": "openai_p/gpt-4", // 客户端发送统一 ID
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
@@ -304,21 +314,23 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
|
||||
assert.Equal(t, "claude-3-opus", req["model"])
|
||||
|
||||
// 上游返回上游模型名
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
|
||||
_, err = w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
|
||||
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
|
||||
"max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
}
|
||||
@@ -352,7 +364,8 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
@@ -393,7 +406,8 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
|
||||
"data: [DONE]\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
@@ -447,11 +461,13 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var anthropicResp map[string]any
|
||||
json.Unmarshal(anthropicBody, &anthropicResp)
|
||||
data := anthropicResp["data"].([]any)
|
||||
require.NoError(t, json.Unmarshal(anthropicBody, &anthropicResp))
|
||||
data, okd := anthropicResp["data"].([]any)
|
||||
require.True(t, okd)
|
||||
assert.Len(t, data, 2)
|
||||
|
||||
first := data[0].(map[string]any)
|
||||
first, okf := data[0].(map[string]any)
|
||||
require.True(t, okf)
|
||||
assert.Equal(t, "gpt-4", first["id"])
|
||||
assert.Equal(t, "model", first["type"])
|
||||
|
||||
@@ -466,11 +482,12 @@ func TestConversion_Models_CrossProtocol(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiResp map[string]any
|
||||
json.Unmarshal(openaiBody, &err)
|
||||
json.Unmarshal(openaiBody, &openaiResp)
|
||||
oaiData := openaiResp["data"].([]any)
|
||||
require.NoError(t, json.Unmarshal(openaiBody, &openaiResp))
|
||||
oaiData, oki := openaiResp["data"].([]any)
|
||||
require.True(t, oki)
|
||||
assert.Len(t, oaiData, 1)
|
||||
firstOai := oaiData[0].(map[string]any)
|
||||
firstOai, okf2 := oaiData[0].(map[string]any)
|
||||
require.True(t, okf2)
|
||||
assert.Equal(t, "claude-3-opus", firstOai["id"])
|
||||
}
|
||||
|
||||
@@ -537,7 +554,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
var created map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &created)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
|
||||
assert.Equal(t, "sk-test", created["api_key"])
|
||||
|
||||
// 获取时应包含 protocol
|
||||
@@ -547,7 +564,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var fetched map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &fetched)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &fetched))
|
||||
assert.Equal(t, "anthropic", fetched["protocol"])
|
||||
}
|
||||
|
||||
@@ -570,11 +587,13 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
var created map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &created)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &created))
|
||||
assert.Equal(t, "openai", created["protocol"])
|
||||
}
|
||||
|
||||
// Suppress unused imports
|
||||
var _ = fmt.Sprintf
|
||||
var _ = strings.Contains
|
||||
var _ = time.Second
|
||||
var (
|
||||
_ = fmt.Sprintf
|
||||
_ = strings.Contains
|
||||
_ = time.Second
|
||||
)
|
||||
|
||||
@@ -12,19 +12,20 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
openaiConv "nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
openaiConv "nex/backend/internal/conversion/openai"
|
||||
)
|
||||
|
||||
func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
|
||||
@@ -33,7 +34,8 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"error":"not mocked"}`))
|
||||
_, err := w.Write([]byte(`{"error":"not mocked"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
db := setupTestDB(t)
|
||||
@@ -115,11 +117,12 @@ func parseSSEEvents(body string) []map[string]string {
|
||||
var currentEvent, currentData string
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event: "):
|
||||
currentEvent = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
case strings.HasPrefix(line, "data: "):
|
||||
currentData = strings.TrimPrefix(line, "data: ")
|
||||
} else if line == "" && (currentEvent != "" || currentData != "") {
|
||||
case line == "" && (currentEvent != "" || currentData != ""):
|
||||
events = append(events, map[string]string{
|
||||
"event": currentEvent,
|
||||
"data": currentData,
|
||||
@@ -157,21 +160,21 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equal(t, "/chat/completions", req.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-001",
|
||||
"object": "chat.completion",
|
||||
"created": 1700000000,
|
||||
"model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "你好!我是AI助手。"},
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "你好!我是AI助手。"},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": nil,
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25,
|
||||
},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
@@ -210,21 +213,23 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
|
||||
func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
msgs := reqBody["messages"].([]any)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
msgs, ok := reqBody["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.GreaterOrEqual(t, len(msgs), 3)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-002", "object": "chat.completion", "created": 1700000001, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "Go语言的interface是隐式实现的。"},
|
||||
"finish_reason": "stop", "logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
@@ -252,7 +257,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-004", "object": "chat.completion", "created": 1700000003, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
@@ -272,7 +277,7 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
@@ -286,9 +291,9 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
|
||||
"function": map[string]any{
|
||||
"name": "get_weather", "description": "获取天气",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": map[string]any{"city": map[string]any{"type": "string"}},
|
||||
"required": []string{"city"},
|
||||
"required": []string{"city"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
@@ -319,22 +324,22 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-014", "object": "chat.completion", "created": 1700000014, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "人工智能起源于1950年代..."},
|
||||
"finish_reason": "length",
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
||||
"max_tokens": 30,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -353,11 +358,11 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-022", "object": "chat.completion", "created": 1700000022, "model": "o3",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "答案是61。"},
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "答案是61。"},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": nil,
|
||||
}},
|
||||
@@ -365,12 +370,12 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
|
||||
"prompt_tokens": 35, "completion_tokens": 48, "total_tokens": 83,
|
||||
"completion_tokens_details": map[string]any{"reasoning_tokens": 20},
|
||||
},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/o3",
|
||||
"model": "openai_p/o3",
|
||||
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -393,12 +398,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-007", "object": "chat.completion", "created": 1700000007, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"role": "assistant",
|
||||
"content": nil,
|
||||
"refusal": "抱歉,我无法提供涉及危险活动的信息。",
|
||||
},
|
||||
@@ -406,12 +411,12 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "做坏事"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -453,9 +458,9 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
@@ -497,14 +502,14 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||
"tools": []map[string]any{{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather", "description": "获取天气",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": map[string]any{"city": map[string]any{"type": "string"}},
|
||||
},
|
||||
},
|
||||
@@ -546,9 +551,9 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
@@ -569,14 +574,14 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_001", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": "你好!我是Claude,由Anthropic开发的AI助手。"},
|
||||
},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -611,24 +616,25 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
|
||||
func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
assert.NotNil(t, reqBody["system"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_003", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "递归是函数调用自身。"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||
"system": "你是编程助手",
|
||||
"system": "你是编程助手",
|
||||
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -643,7 +649,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_009", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "tool_use", "id": "toolu_e2e_009", "name": "get_weather",
|
||||
@@ -651,7 +657,7 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
|
||||
}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -661,9 +667,9 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
|
||||
"tools": []map[string]any{{
|
||||
"name": "get_weather", "description": "获取天气",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": map[string]any{"city": map[string]any{"type": "string"}},
|
||||
"required": []string{"city"},
|
||||
"required": []string{"city"},
|
||||
},
|
||||
}},
|
||||
"tool_choice": map[string]any{"type": "auto"},
|
||||
@@ -689,7 +695,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_018", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{
|
||||
{"type": "thinking", "thinking": "这是一个逻辑推理问题..."},
|
||||
@@ -697,7 +703,7 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
|
||||
},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -724,12 +730,12 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_016", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "人工智能起源于..."}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "max_tokens", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -752,18 +758,18 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_017", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "1\n2\n3\n4\n"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
|
||||
"model": "claude-opus-4-7", "stop_reason": "stop_sequence", "stop_sequence": "5",
|
||||
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
||||
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
||||
"stop_sequences": []string{"5"},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -781,19 +787,20 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
|
||||
func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
metadata, _ := reqBody["metadata"].(map[string]any)
|
||||
assert.Equal(t, "user_12345", metadata["user_id"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_026", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "你好!"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -814,21 +821,21 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_025", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "你好!"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 25, "output_tokens": 5,
|
||||
"cache_creation_input_tokens": 15, "cache_read_input_tokens": 0,
|
||||
},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
|
||||
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
|
||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -864,7 +871,8 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
@@ -874,7 +882,7 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
@@ -922,7 +930,7 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
|
||||
"messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
|
||||
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
@@ -961,14 +969,14 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_cross_001", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "跨协议响应"}},
|
||||
"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-model", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||
})
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-model",
|
||||
"model": "anthropic_p/claude-model",
|
||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -1050,9 +1058,9 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-model",
|
||||
"model": "anthropic_p/claude-model",
|
||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
@@ -1092,7 +1100,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4", "max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
@@ -1128,7 +1136,7 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/nonexistent",
|
||||
"model": "openai_p/nonexistent",
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -1183,11 +1191,11 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
|
||||
"content": nil,
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_ptc_1", "type": "function",
|
||||
"id": "call_ptc_1", "type": "function",
|
||||
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"北京"}`},
|
||||
},
|
||||
{
|
||||
"id": "call_ptc_2", "type": "function",
|
||||
"id": "call_ptc_2", "type": "function",
|
||||
"function": map[string]any{"name": "get_weather", "arguments": `{"city":"上海"}`},
|
||||
},
|
||||
},
|
||||
@@ -1201,7 +1209,7 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
|
||||
"tools": []map[string]any{{
|
||||
"type": "function",
|
||||
@@ -1242,10 +1250,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-stop", "object": "chat.completion", "created": 1700000060, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": nil,
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "1, 2, 3, 4, "},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||
})
|
||||
@@ -1253,9 +1261,9 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
||||
"stop": []string{"5"},
|
||||
"stop": []string{"5"},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
@@ -1291,7 +1299,7 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "危险内容"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -1353,21 +1361,22 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
|
||||
func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
tc, _ := reqBody["tool_choice"].(map[string]any)
|
||||
assert.Equal(t, "any", tc["type"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_tca", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{
|
||||
{"type": "tool_use", "id": "toolu_tca_1", "name": "get_time", "input": map[string]any{"timezone": "Asia/Shanghai"}},
|
||||
},
|
||||
"model": "claude-opus-4-7", "stop_reason": "tool_use", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -1397,20 +1406,21 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
|
||||
func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
sys, ok := reqBody["system"].([]any)
|
||||
require.True(t, ok, "system should be an array")
|
||||
require.GreaterOrEqual(t, len(sys), 1)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_asys", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "已收到多条系统指令。"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -1433,21 +1443,22 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
|
||||
func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
msgs := reqBody["messages"].([]any)
|
||||
require.GreaterOrEqual(t, len(msgs), 3)
|
||||
lastMsg := msgs[len(msgs)-1].(map[string]any)
|
||||
assert.Equal(t, "user", lastMsg["role"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_e2e_tr", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "北京当前晴天,温度25°C。"}},
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"model": "claude-opus-4-7", "stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||
|
||||
@@ -1497,7 +1508,8 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
@@ -1559,7 +1571,7 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-model",
|
||||
"model": "anthropic_p/claude-model",
|
||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||
"tools": []map[string]any{{
|
||||
"type": "function",
|
||||
@@ -1634,14 +1646,14 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg_cross_stop", "type": "message", "role": "assistant",
|
||||
"content": []map[string]any{{"type": "text", "text": "被截断的内容..."}},
|
||||
"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
|
||||
"model": "claude-model", "stop_reason": "max_tokens", "stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
|
||||
})
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-model",
|
||||
"model": "anthropic_p/claude-model",
|
||||
"messages": []map[string]any{{"role": "user", "content": "长文"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -1659,9 +1671,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
|
||||
func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
|
||||
r, upstream := setupE2ETest(t)
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
var reqBody map[string]any
|
||||
json.Unmarshal(body, &reqBody)
|
||||
require.NoError(t, json.Unmarshal(body, &reqBody))
|
||||
msgs := reqBody["messages"].([]any)
|
||||
require.GreaterOrEqual(t, len(msgs), 3)
|
||||
toolMsg := msgs[2].(map[string]any)
|
||||
@@ -1669,16 +1682,16 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
|
||||
assert.Equal(t, "call_e2e_001", toolMsg["tool_call_id"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-e2e-tr", "object": "chat.completion", "created": 1700000080, "model": "gpt-4o",
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "北京当前晴天,温度25°C。"},
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "北京当前晴天,温度25°C。"},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": nil,
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
||||
})
|
||||
}))
|
||||
})
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
@@ -1722,7 +1735,8 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
@@ -1730,7 +1744,7 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-model",
|
||||
"model": "anthropic_p/claude-model",
|
||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||
"tools": []map[string]any{{
|
||||
"type": "function",
|
||||
@@ -1817,7 +1831,7 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
|
||||
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "openai_p/gpt-4o",
|
||||
"model": "openai_p/gpt-4o",
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -1879,7 +1893,8 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"正常\"}}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
_, err := w.Write([]byte(e))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
@@ -1889,7 +1904,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
@@ -1902,5 +1917,7 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
|
||||
assert.Contains(t, respBody, "正常")
|
||||
}
|
||||
|
||||
var _ = fmt.Sprintf
|
||||
var _ = time.Now
|
||||
var (
|
||||
_ = fmt.Sprintf
|
||||
_ = time.Now
|
||||
)
|
||||
|
||||
@@ -7,16 +7,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -97,7 +98,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
var createdModel domain.Model
|
||||
json.Unmarshal(w.Body.Bytes(), &createdModel)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
|
||||
assert.NotEmpty(t, createdModel.ID)
|
||||
|
||||
// 3. 列出 Provider
|
||||
@@ -106,7 +107,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var providers []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &providers)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Equal(t, "sk-test-key", providers[0].APIKey)
|
||||
|
||||
@@ -116,7 +117,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var models []domain.Model
|
||||
json.Unmarshal(w.Body.Bytes(), &models)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
|
||||
assert.Len(t, models, 1)
|
||||
assert.Equal(t, "gpt-4", models[0].ModelName)
|
||||
|
||||
@@ -163,7 +164,7 @@ func TestAnthropic_ModelCreation(t *testing.T) {
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
var createdModel domain.Model
|
||||
json.Unmarshal(w.Body.Bytes(), &createdModel)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &createdModel))
|
||||
|
||||
// 验证创建成功
|
||||
w = httptest.NewRecorder()
|
||||
@@ -194,9 +195,9 @@ func TestStats_RecordingAndQuery(t *testing.T) {
|
||||
|
||||
// 直接通过 repository 记录统计(模拟代理请求后的统计记录)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
|
||||
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
|
||||
require.NoError(t, statsRepo.Record("p1", "gpt-4"))
|
||||
|
||||
// 查询统计
|
||||
w = httptest.NewRecorder()
|
||||
@@ -205,7 +206,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var stats []domain.UsageStats
|
||||
json.Unmarshal(w.Body.Bytes(), &stats)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 3, stats[0].RequestCount)
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。
|
||||
|
||||
@@ -3,9 +3,10 @@ package tests
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,9 +11,10 @@ package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
conversion "nex/backend/internal/conversion"
|
||||
provider "nex/backend/internal/provider"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,10 +10,11 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,10 +10,11 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
domain "nex/backend/internal/domain"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ func TestConstraint_UniqueProviderModel(t *testing.T) {
|
||||
}
|
||||
err = db.Create(&model2).Error
|
||||
assert.Error(t, err, "创建相同 (provider_id, model_name) 的 model 应失败")
|
||||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
|
||||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
|
||||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
|
||||
"错误应为唯一约束错误")
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func TestConstraint_UniqueUsageStats(t *testing.T) {
|
||||
}
|
||||
err = db.Create(&stats2).Error
|
||||
assert.Error(t, err, "创建相同 (provider_id, model_name, date) 的 usage_stats 应失败")
|
||||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
|
||||
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
|
||||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
|
||||
"错误应为唯一约束错误")
|
||||
}
|
||||
|
||||
5
lefthook.yml
Normal file
5
lefthook.yml
Normal file
@@ -0,0 +1,5 @@
|
||||
pre-commit:
|
||||
commands:
|
||||
backend-lint:
|
||||
glob: "backend/**/*.go"
|
||||
run: cd backend && go tool golangci-lint run --new-from-rev HEAD ./...
|
||||
@@ -169,15 +169,15 @@
|
||||
|
||||
- **WHEN** 应用启动
|
||||
- **THEN** SHALL 按以下顺序加载配置:
|
||||
1. 解析 CLI 参数(获取 --config 路径)
|
||||
2. 初始化配置管理器
|
||||
3. 设置默认值
|
||||
4. 绑定 CLI 参数
|
||||
5. 绑定环境变量
|
||||
6. 读取配置文件
|
||||
7. 反序列化到结构体
|
||||
8. 验证配置
|
||||
9. 打印配置摘要
|
||||
1. 解析 CLI 参数(获取 --config 路径)
|
||||
2. 初始化配置管理器
|
||||
3. 设置默认值
|
||||
4. 绑定 CLI 参数
|
||||
5. 绑定环境变量
|
||||
6. 读取配置文件(不存在时自动创建)
|
||||
7. 反序列化到结构体
|
||||
8. 验证配置
|
||||
9. 打印配置摘要
|
||||
|
||||
#### Scenario: 加载失败处理
|
||||
|
||||
|
||||
@@ -31,16 +31,27 @@
|
||||
- **THEN** SHALL 测试请求转换、响应转换、流式转换
|
||||
- **THEN** SHALL 验证转换的准确性和完整性
|
||||
|
||||
#### Scenario: config 加载管道集成测试
|
||||
#### Scenario: LoadConfigFromPath 默认值验证
|
||||
|
||||
- **WHEN** 运行 config 加载管道的集成测试
|
||||
- **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值
|
||||
- **THEN** SHALL 验证环境变量(`NEX_` 前缀)覆盖默认值
|
||||
- **THEN** SHALL 验证 YAML 配置文件正确读取
|
||||
- **THEN** SHALL 验证优先级链:CLI 参数 > 环境变量 > YAML 文件 > 默认值
|
||||
- **THEN** SHALL 验证首次启动自动创建配置文件
|
||||
- **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致
|
||||
|
||||
#### Scenario: 环境变量覆盖验证
|
||||
|
||||
- **WHEN** 设置 `NEX_SERVER_PORT=9000` 和 `NEX_LOG_LEVEL=debug`
|
||||
- **THEN** SHALL 成功加载
|
||||
- **THEN** 配置值 SHALL 反映环境变量覆盖
|
||||
|
||||
#### Scenario: 自动创建配置文件验证
|
||||
|
||||
- **WHEN** 调用 `LoadConfigFromPath` 并指向不存在的文件路径
|
||||
- **THEN** SHALL 成功加载(不返回 `missing configuration for 'configPath'` 错误)
|
||||
- **THEN** SHALL 返回默认配置对象
|
||||
|
||||
#### Scenario: handler 错误分支测试
|
||||
|
||||
- **WHEN** 运行 handler 层的单元测试
|
||||
|
||||
Reference in New Issue
Block a user