feat: 增强桌面启动失败提示与测试覆盖
This commit is contained in:
@@ -4,17 +4,35 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||
return []promptChannel{
|
||||
{
|
||||
name: "macos-notification",
|
||||
available: func() error {
|
||||
_, err := runner.LookPath("osascript")
|
||||
return err
|
||||
},
|
||||
run: func(req promptRequest) error {
|
||||
script := fmt.Sprintf(`display notification "%s" with title "%s" subtitle "%s"`,
|
||||
escapeAppleScript(req.message), escapeAppleScript(req.title), escapeAppleScript(req.subtitle))
|
||||
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "macos-alert",
|
||||
available: func() error {
|
||||
_, err := runner.LookPath("osascript")
|
||||
return err
|
||||
},
|
||||
run: func(req promptRequest) error {
|
||||
script := fmt.Sprintf(`display alert "%s" message "%s" as critical buttons {"OK"} default button "OK"`,
|
||||
escapeAppleScript(req.title), escapeAppleScript(req.message))
|
||||
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
46
backend/cmd/desktop/dialog_darwin_test.go
Normal file
46
backend/cmd/desktop/dialog_darwin_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDarwinStartupChannelsBuildNotificationAndAlert(t *testing.T) {
|
||||
runner := &fakeCommandRunner{paths: map[string]bool{"osascript": true}}
|
||||
channels := platformStartupChannels(runner)
|
||||
if len(channels) != 2 {
|
||||
t.Fatalf("macOS 应有 notification 和 alert 两级通道,实际: %d", len(channels))
|
||||
}
|
||||
|
||||
req := promptRequest{title: "Nex 启动失败", subtitle: "config", message: "路径 C:\\tmp 包含 \"quote\""}
|
||||
for _, channel := range channels {
|
||||
if err := channel.available(); err != nil {
|
||||
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
|
||||
}
|
||||
if err := channel.run(req); err != nil {
|
||||
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(runner.calls) != 2 {
|
||||
t.Fatalf("应执行两次 osascript,实际: %d", len(runner.calls))
|
||||
}
|
||||
if runner.calls[0].name != "osascript" || runner.calls[0].args[0] != "-e" {
|
||||
t.Fatalf("notification 命令参数错误: %#v", runner.calls[0])
|
||||
}
|
||||
if script := runner.calls[0].args[1]; !strings.Contains(script, "display notification") || !strings.Contains(script, `\\tmp`) || !strings.Contains(script, `\"quote\"`) {
|
||||
t.Fatalf("notification AppleScript 未正确构造或转义: %s", script)
|
||||
}
|
||||
if script := runner.calls[1].args[1]; !strings.Contains(script, "display alert") || !strings.Contains(script, "as critical") {
|
||||
t.Fatalf("alert AppleScript 未使用 critical 告警: %s", script)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeAppleScript(t *testing.T) {
|
||||
got := escapeAppleScript(`C:\tmp "quote"`)
|
||||
if !strings.Contains(got, `C:\\tmp`) || !strings.Contains(got, `\"quote\"`) {
|
||||
t.Fatalf("AppleScript 转义结果错误: %s", got)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -12,56 +13,99 @@ type dialogToolType int
|
||||
|
||||
const (
|
||||
toolNone dialogToolType = iota
|
||||
toolZenity
|
||||
toolKdialog
|
||||
toolNotifySend
|
||||
toolKdialogPassive
|
||||
toolZenity
|
||||
toolKdialogError
|
||||
toolXmessage
|
||||
)
|
||||
|
||||
var (
|
||||
dialogTool dialogToolType
|
||||
dialogToolOnce sync.Once
|
||||
dialogTools map[string]bool
|
||||
dialogToolOnce sync.Once
|
||||
dialogToolNames = []string{"notify-send", "kdialog", "zenity", "xmessage"}
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialogToolOnce.Do(detectDialogTool)
|
||||
dialogToolOnce.Do(func() { detectDialogTools(defaultCommandRunner{}) })
|
||||
}
|
||||
|
||||
func detectDialogTool() {
|
||||
tools := []struct {
|
||||
name string
|
||||
typ dialogToolType
|
||||
}{
|
||||
{"zenity", toolZenity},
|
||||
{"kdialog", toolKdialog},
|
||||
{"notify-send", toolNotifySend},
|
||||
{"xmessage", toolXmessage},
|
||||
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||
return []promptChannel{
|
||||
linuxCommandChannel("notify-send", toolNotifySend, runner, linuxHasGraphicalSessionAndDBus, func(req promptRequest) []string {
|
||||
return []string{"-u", "critical", "-a", appName, "-i", "nex", req.title, req.message}
|
||||
}),
|
||||
linuxCommandChannel("kdialog", toolKdialogPassive, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||
return []string{"--title", req.title, "--passivepopup", req.message, "10"}
|
||||
}),
|
||||
linuxCommandChannel("zenity", toolZenity, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||
return []string{"--error", fmt.Sprintf("--title=%s", req.title), fmt.Sprintf("--text=%s", req.message)}
|
||||
}),
|
||||
linuxCommandChannel("kdialog", toolKdialogError, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
|
||||
return []string{"--title", req.title, "--error", req.message}
|
||||
}),
|
||||
linuxCommandChannel("xmessage", toolXmessage, runner, linuxHasX11Display, func(req promptRequest) []string {
|
||||
return []string{"-center", "-buttons", "OK:0", "-default", "OK", fmt.Sprintf("%s: %s", req.title, req.message)}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool.name); err == nil {
|
||||
dialogTool = tool.typ
|
||||
return
|
||||
func detectDialogTools(runner commandRunner) {
|
||||
dialogTools = make(map[string]bool, len(dialogToolNames))
|
||||
for _, name := range dialogToolNames {
|
||||
_, err := runner.LookPath(name)
|
||||
dialogTools[name] = err == nil
|
||||
}
|
||||
}
|
||||
|
||||
func linuxCommandChannel(name string, typ dialogToolType, runner commandRunner, environmentOK func() error, args func(promptRequest) []string) promptChannel {
|
||||
return promptChannel{
|
||||
name: fmt.Sprintf("linux-%s-%d", name, typ),
|
||||
available: func() error {
|
||||
if err := linuxCommandAvailable(runner, name); err != nil {
|
||||
return err
|
||||
}
|
||||
return environmentOK()
|
||||
},
|
||||
run: func(req promptRequest) error {
|
||||
return runner.Run(promptCommandTimeout, nil, name, args(req)...)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func linuxCommandAvailable(runner commandRunner, name string) error {
|
||||
if _, ok := runner.(defaultCommandRunner); ok {
|
||||
dialogToolOnce.Do(func() { detectDialogTools(runner) })
|
||||
if dialogTools[name] {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%s 不可用", name)
|
||||
}
|
||||
|
||||
dialogTool = toolNone
|
||||
_, err := runner.LookPath(name)
|
||||
return err
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch dialogTool {
|
||||
case toolZenity:
|
||||
exec.Command("zenity", "--error",
|
||||
fmt.Sprintf("--title=%s", title),
|
||||
fmt.Sprintf("--text=%s", message)).Run()
|
||||
case toolKdialog:
|
||||
exec.Command("kdialog", "--error", message, "--title", title).Run()
|
||||
case toolNotifySend:
|
||||
exec.Command("notify-send", "-u", "critical", title, message).Run()
|
||||
case toolXmessage:
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("%s: %s", title, message)).Run()
|
||||
default:
|
||||
dialogLogger().Error("无法显示错误对话框")
|
||||
func linuxHasGraphicalSession() error {
|
||||
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
|
||||
return errors.New("缺少图形会话")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func linuxHasGraphicalSessionAndDBus() error {
|
||||
if err := linuxHasGraphicalSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
if os.Getenv("DBUS_SESSION_BUS_ADDRESS") == "" {
|
||||
return errors.New("缺少 DBus session bus")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func linuxHasX11Display() error {
|
||||
if os.Getenv("DISPLAY") == "" {
|
||||
return errors.New("缺少 X11 DISPLAY")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
61
backend/cmd/desktop/dialog_linux_test.go
Normal file
61
backend/cmd/desktop/dialog_linux_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLinuxStartupChannelsPriorityAndArguments(t *testing.T) {
|
||||
t.Setenv("DISPLAY", ":0")
|
||||
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "unix:path=/tmp/dbus")
|
||||
runner := &fakeCommandRunner{paths: map[string]bool{
|
||||
"notify-send": true,
|
||||
"kdialog": true,
|
||||
"zenity": true,
|
||||
"xmessage": true,
|
||||
}}
|
||||
|
||||
channels := platformStartupChannels(runner)
|
||||
if len(channels) != 5 {
|
||||
t.Fatalf("Linux 应有 5 个 UI 通道,实际: %d", len(channels))
|
||||
}
|
||||
|
||||
req := promptRequest{title: "Nex 启动失败", message: "端口被占用"}
|
||||
for _, channel := range channels {
|
||||
if err := channel.available(); err != nil {
|
||||
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
|
||||
}
|
||||
if err := channel.run(req); err != nil {
|
||||
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
wantNames := []string{"notify-send", "kdialog", "zenity", "kdialog", "xmessage"}
|
||||
for i, want := range wantNames {
|
||||
if got := runner.calls[i].name; got != want {
|
||||
t.Fatalf("第 %d 个命令 = %s, want %s", i, got, want)
|
||||
}
|
||||
}
|
||||
if got := runner.calls[0].args; len(got) < 2 || got[0] != "-u" || got[1] != "critical" {
|
||||
t.Fatalf("notify-send 应使用 critical 参数,实际: %#v", got)
|
||||
}
|
||||
if got := runner.calls[1].args; len(got) < 3 || got[2] != "--passivepopup" {
|
||||
t.Fatalf("kdialog 第一跳应使用 passivepopup,实际: %#v", got)
|
||||
}
|
||||
if got := runner.calls[2].args; len(got) < 1 || got[0] != "--error" {
|
||||
t.Fatalf("zenity 应使用 --error,实际: %#v", got)
|
||||
}
|
||||
if got := runner.calls[4].args; len(got) < 1 || got[0] != "-center" {
|
||||
t.Fatalf("xmessage 应居中显示,实际: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinuxNotifySendRequiresDBus(t *testing.T) {
|
||||
t.Setenv("DISPLAY", ":0")
|
||||
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "")
|
||||
runner := &fakeCommandRunner{paths: map[string]bool{"notify-send": true}}
|
||||
|
||||
channels := platformStartupChannels(runner)
|
||||
if err := channels[0].available(); err == nil {
|
||||
t.Fatal("notify-send 缺少 DBus session bus 时应不可用")
|
||||
}
|
||||
}
|
||||
@@ -3,17 +3,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
mbOK = 0x00000000
|
||||
mbIconError = 0x10
|
||||
mbIconInformation = 0x40
|
||||
mbTaskModal = 0x00002000
|
||||
mbSetForeground = 0x00010000
|
||||
mbTopMost = 0x00040000
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -25,12 +29,79 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
if err := messageBox(title, message, mbIconError); err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
|
||||
func platformStartupChannels(runner commandRunner) []promptChannel {
|
||||
return []promptChannel{
|
||||
{
|
||||
name: "windows-toast",
|
||||
available: func() error {
|
||||
_, err := findPowerShell(runner)
|
||||
return err
|
||||
},
|
||||
run: func(req promptRequest) error {
|
||||
name, err := findPowerShell(runner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runner.Run(promptCommandTimeout, []string{
|
||||
"NEX_TOAST_TITLE=" + req.title,
|
||||
"NEX_TOAST_BODY=" + req.message,
|
||||
}, name, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-EncodedCommand", encodePowerShellCommand(windowsToastScript()))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "windows-messagebox",
|
||||
available: func() error {
|
||||
return messageBoxAvailable()
|
||||
},
|
||||
run: func(req promptRequest) error {
|
||||
return messageBox(req.title, req.message, messageBoxStartupFlags())
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func findPowerShell(runner commandRunner) (string, error) {
|
||||
for _, name := range []string{"powershell.exe", "powershell"} {
|
||||
if _, err := runner.LookPath(name); err == nil {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("PowerShell 不可用")
|
||||
}
|
||||
|
||||
func windowsToastScript() string {
|
||||
return `$ErrorActionPreference = 'Stop'
|
||||
Add-Type -AssemblyName System.Runtime.WindowsRuntime
|
||||
$template = [Windows.UI.Notifications.ToastTemplateType]::ToastText02
|
||||
$xml = [Windows.UI.Notifications.ToastNotificationManager]::GetTemplateContent($template)
|
||||
$texts = $xml.GetElementsByTagName('text')
|
||||
$texts.Item(0).AppendChild($xml.CreateTextNode($env:NEX_TOAST_TITLE)) | Out-Null
|
||||
$texts.Item(1).AppendChild($xml.CreateTextNode($env:NEX_TOAST_BODY)) | Out-Null
|
||||
$toast = [Windows.UI.Notifications.ToastNotification]::new($xml)
|
||||
[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('Nex').Show($toast)`
|
||||
}
|
||||
|
||||
func encodePowerShellCommand(script string) string {
|
||||
encoded := utf16.Encode([]rune(script))
|
||||
buf := make([]byte, 0, len(encoded)*2)
|
||||
for _, value := range encoded {
|
||||
buf = append(buf, byte(value), byte(value>>8))
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func messageBoxAvailable() error {
|
||||
if _, err := syscall.UTF16PtrFromString("Nex"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := syscall.UTF16PtrFromString("test"); err != nil {
|
||||
return err
|
||||
}
|
||||
return procMessageBoxW.Find()
|
||||
}
|
||||
|
||||
func messageBoxStartupFlags() uint {
|
||||
return mbOK | mbIconError | mbTaskModal | mbSetForeground | mbTopMost
|
||||
}
|
||||
|
||||
func messageBox(title, message string, flags uint) error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
@@ -27,10 +28,10 @@ import (
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
@@ -40,31 +41,65 @@ var (
|
||||
zapLogger *zap.Logger
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
desktopHooks = defaultDesktopRuntimeHooks()
|
||||
)
|
||||
|
||||
type singletonLocker interface {
|
||||
Lock() error
|
||||
Unlock() error
|
||||
}
|
||||
|
||||
type desktopRuntimeHooks struct {
|
||||
loadConfig func() (*config.Config, config.ConfigMetadata, error)
|
||||
newLock func(string) singletonLocker
|
||||
listen func(int) (net.Listener, error)
|
||||
upgradeLogger func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error)
|
||||
initDB func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error)
|
||||
closeDB func(*gorm.DB)
|
||||
registerAdapters func(conversion.AdapterRegistry) error
|
||||
setupStaticFiles func(*gin.Engine) error
|
||||
startServer func(*http.Server, net.Listener, chan<- error, *zap.Logger)
|
||||
setupSystray func(int, <-chan error) error
|
||||
}
|
||||
|
||||
func defaultDesktopRuntimeHooks() desktopRuntimeHooks {
|
||||
return desktopRuntimeHooks{
|
||||
loadConfig: config.LoadDesktopConfigWithMetadata,
|
||||
newLock: func(lockPath string) singletonLocker { return NewSingletonLock(lockPath) },
|
||||
listen: listenDesktopPort,
|
||||
upgradeLogger: pkgLogger.Upgrade,
|
||||
initDB: database.Init,
|
||||
closeDB: database.Close,
|
||||
registerAdapters: registerDesktopAdapters,
|
||||
setupStaticFiles: setupStaticFiles,
|
||||
startServer: startDesktopServer,
|
||||
setupSystray: setupSystray,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, cfgMeta, err := config.LoadDesktopConfigWithMetadata()
|
||||
if err != nil {
|
||||
minimalLogger.Error("加载配置失败", zap.Error(err))
|
||||
showError(appName, desktopConfigErrorMessage(getDesktopConfigPath(), err))
|
||||
if err := runDesktop(minimalLogger); err != nil {
|
||||
reportStartupFailure(err, dialogLogger())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func runDesktop(minimalLogger *zap.Logger) error {
|
||||
if minimalLogger == nil {
|
||||
minimalLogger = pkgLogger.NewMinimal()
|
||||
}
|
||||
|
||||
cfg, cfgMeta, err := desktopHooks.loadConfig()
|
||||
if err != nil {
|
||||
return newStartupError(phaseConfig, desktopConfigErrorMessage(getDesktopConfigPath(), err), err)
|
||||
}
|
||||
|
||||
port := cfg.Server.Port
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
minimalLogger.Error("端口不可用", zap.Error(err))
|
||||
showError(appName, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||
singleLock := desktopHooks.newLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||
if err := singleLock.Lock(); err != nil {
|
||||
minimalLogger.Error("已有 Nex 实例运行")
|
||||
showError(appName, "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
return newStartupError(phaseSingleton, "已有 Nex 实例运行", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := singleLock.Unlock(); err != nil {
|
||||
@@ -72,7 +107,13 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
listener, err := desktopHooks.listen(port)
|
||||
if err != nil {
|
||||
return newStartupError(phasePort, desktopPortUnavailableMessage(port), err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
zapLogger, err = desktopHooks.upgradeLogger(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -81,7 +122,7 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
return newStartupError(phaseLogger, fmt.Sprintf("初始化日志失败\n\n日志目录: %s\n\n请检查目录权限或磁盘空间", cfg.Log.Path), err)
|
||||
}
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
@@ -91,11 +132,17 @@ func main() {
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
db, err := desktopHooks.initDB(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
phase := phaseDatabase
|
||||
message := fmt.Sprintf("数据库初始化失败\n\n请检查数据库配置、文件权限或连接状态\n\n%v", err)
|
||||
if errors.Is(err, database.ErrMigration) {
|
||||
phase = phaseMigration
|
||||
message = fmt.Sprintf("数据库迁移失败\n\n请查看日志或检查数据库迁移权限\n\n%v", err)
|
||||
}
|
||||
return newStartupError(phase, message, err)
|
||||
}
|
||||
defer database.Close(db)
|
||||
defer desktopHooks.closeDB(db)
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
@@ -118,11 +165,8 @@ func main() {
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
if err := desktopHooks.registerAdapters(registry); err != nil {
|
||||
return newStartupError(phaseAdapter, startupInternalErrorMessage(), err)
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
@@ -144,7 +188,9 @@ func main() {
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
|
||||
setupStaticFiles(r)
|
||||
if err := desktopHooks.setupStaticFiles(r); err != nil {
|
||||
return newStartupError(phaseStaticResource, startupInternalErrorMessage(), err)
|
||||
}
|
||||
|
||||
server = &http.Server{
|
||||
Addr: desktopListenAddr(port),
|
||||
@@ -154,26 +200,46 @@ func main() {
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||
defer doShutdown()
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
desktopHooks.startServer(server, listener, serverErrCh, zapLogger)
|
||||
select {
|
||||
case err := <-serverErrCh:
|
||||
return newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
if err := desktopHooks.setupSystray(port, serverErrCh); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverErrCh:
|
||||
return newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func registerDesktopAdapters(registry conversion.AdapterRegistry) error {
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
return err
|
||||
}
|
||||
return registry.Register(anthropic.NewAdapter())
|
||||
}
|
||||
|
||||
func startDesktopServer(server *http.Server, listener net.Listener, serverErrCh chan<- error, logger *zap.Logger) {
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
logger.Info("AI Gateway 启动",
|
||||
zap.String("addr", server.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
serverErrCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(desktopURL(port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
setupSystray(port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
|
||||
@@ -223,12 +289,13 @@ func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
func setupStaticFiles(r *gin.Engine) error {
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func frontendDistFS() (fs.FS, error) {
|
||||
@@ -299,47 +366,6 @@ func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||
})
|
||||
}
|
||||
|
||||
func setupSystray(port int) {
|
||||
systray.Run(func() {
|
||||
var icon []byte
|
||||
var err error
|
||||
if runtime.GOOS == "windows" {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
|
||||
} else {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.Error(err))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTooltip(appTooltip)
|
||||
|
||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
systray.AddSeparator()
|
||||
mStatus := systray.AddMenuItem("状态: 运行中", "")
|
||||
mStatus.Disable()
|
||||
mPort := systray.AddMenuItem(desktopPortMenuTitle(port), "")
|
||||
mPort.Disable()
|
||||
systray.AddSeparator()
|
||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
if err := openBrowser(desktopURL(port)); err != nil {
|
||||
zapLogger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mQuit.ClickedCh:
|
||||
doShutdown()
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func doShutdown() {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
@@ -382,13 +408,12 @@ func desktopPortMenuTitle(port int) string {
|
||||
return fmt.Sprintf("端口: %d", port)
|
||||
}
|
||||
|
||||
func checkPortAvailable(port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
func listenDesktopPort(port int) (net.Listener, error) {
|
||||
return net.Listen("tcp", desktopListenAddr(port))
|
||||
}
|
||||
|
||||
func desktopPortUnavailableMessage(port int) string {
|
||||
return fmt.Sprintf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
|
||||
type SingletonLock struct {
|
||||
|
||||
@@ -47,9 +47,15 @@ func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestShowError_WindowsBranch(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
old := buildPromptChannels
|
||||
buildPromptChannels = func(commandRunner) []promptChannel {
|
||||
return []promptChannel{{
|
||||
name: "fake-failed-channel",
|
||||
available: func() error { return nil },
|
||||
run: func(promptRequest) error { return syscall.Errno(5) },
|
||||
}}
|
||||
}
|
||||
t.Cleanup(func() { buildPromptChannels = old })
|
||||
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
@@ -59,3 +65,42 @@ func TestShowError_WindowsBranch(t *testing.T) {
|
||||
|
||||
showError("测试错误", "这是一条测试错误消息")
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_StartupFlags(t *testing.T) {
|
||||
var gotFlags uintptr
|
||||
withMessageBoxW(t, func(_, _, _, flags uintptr) (uintptr, error) {
|
||||
gotFlags = flags
|
||||
return 1, syscall.Errno(0)
|
||||
})
|
||||
|
||||
if err := messageBox("测试标题", "测试消息", messageBoxStartupFlags()); err != nil {
|
||||
t.Fatalf("MessageBoxW 应成功: %v", err)
|
||||
}
|
||||
|
||||
for _, flag := range []uint{mbIconError, mbTaskModal, mbSetForeground, mbTopMost} {
|
||||
if gotFlags&uintptr(flag) == 0 {
|
||||
t.Fatalf("startup flags 缺少 0x%x,实际: 0x%x", flag, gotFlags)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsStartupChannelsUseToastBeforeMessageBox(t *testing.T) {
|
||||
runner := &fakeCommandRunner{paths: map[string]bool{"powershell.exe": true}}
|
||||
channels := platformStartupChannels(runner)
|
||||
if len(channels) != 2 {
|
||||
t.Fatalf("Windows 应有 Toast 和 MessageBox 两级通道,实际: %d", len(channels))
|
||||
}
|
||||
|
||||
if channels[0].name != "windows-toast" || channels[1].name != "windows-messagebox" {
|
||||
t.Fatalf("Windows 通道顺序错误: %s, %s", channels[0].name, channels[1].name)
|
||||
}
|
||||
if err := channels[0].available(); err != nil {
|
||||
t.Fatalf("PowerShell 存在时 Toast 通道应可用: %v", err)
|
||||
}
|
||||
if err := channels[0].run(promptRequest{title: "Nex 启动失败", message: "端口被占用"}); err != nil {
|
||||
t.Fatalf("Toast fake runner 应执行成功: %v", err)
|
||||
}
|
||||
if len(runner.calls) != 1 || runner.calls[0].name != "powershell.exe" {
|
||||
t.Fatalf("Toast 应调用 powershell.exe,实际: %#v", runner.calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,87 +9,27 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCheckPortAvailable(t *testing.T) {
|
||||
port := 19826
|
||||
|
||||
err := checkPortAvailable(port)
|
||||
func TestListenDesktopPortReturnsReusableListener(t *testing.T) {
|
||||
listener, err := listenDesktopPort(0)
|
||||
if err != nil {
|
||||
t.Fatalf("端口 %d 应该可用: %v", port, err)
|
||||
}
|
||||
|
||||
t.Log("端口可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
t.Fatalf("listener-first 应直接获取配置端口 listener: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
t.Log("端口占用检测测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||
defer server.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := server.Serve(listener)
|
||||
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("serve failed: %v", err)
|
||||
t.Errorf("使用同一个 listener 启动 server 失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
listener.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口关闭后应该可用: %v", err)
|
||||
if err := server.Close(); err != nil {
|
||||
t.Fatalf("关闭测试 server 失败: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口关闭后可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableErrorContainsPort(t *testing.T) {
|
||||
port := 19829
|
||||
|
||||
listener, err := net.Listen("tcp", ":19829") //nolint:gosec
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "19829") {
|
||||
t.Fatalf("错误信息应包含端口号 19829,实际: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口错误信息包含端口号测试通过")
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestGetDesktopConfigPath(t *testing.T) {
|
||||
|
||||
121
backend/cmd/desktop/reporter.go
Normal file
121
backend/cmd/desktop/reporter.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const promptCommandTimeout = 5 * time.Second
|
||||
|
||||
type promptRequest struct {
|
||||
title string
|
||||
message string
|
||||
subtitle string
|
||||
}
|
||||
|
||||
type promptChannel struct {
|
||||
name string
|
||||
available func() error
|
||||
run func(promptRequest) error
|
||||
}
|
||||
|
||||
type commandRunner interface {
|
||||
LookPath(file string) (string, error)
|
||||
Run(timeout time.Duration, env []string, name string, args ...string) error
|
||||
}
|
||||
|
||||
type defaultCommandRunner struct{}
|
||||
|
||||
var buildPromptChannels = platformStartupChannels
|
||||
|
||||
func (defaultCommandRunner) LookPath(file string) (string, error) {
|
||||
return exec.LookPath(file)
|
||||
}
|
||||
|
||||
func (defaultCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
if len(env) > 0 {
|
||||
cmd.Env = append(os.Environ(), env...)
|
||||
}
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
reportPrompt(promptRequest{title: title, message: message}, os.Stderr, dialogLogger())
|
||||
}
|
||||
|
||||
func reportStartupFailure(err error, logger *zap.Logger) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var startupErr *startupError
|
||||
if !errors.As(err, &startupErr) {
|
||||
startupErr = newStartupError(phaseServer, startupServerErrorMessage(), err)
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = dialogLogger()
|
||||
}
|
||||
logger.Error("desktop 启动失败",
|
||||
zap.String("phase", startupErr.Phase()),
|
||||
zap.Error(startupErr))
|
||||
|
||||
reportPrompt(promptRequest{
|
||||
title: startupTitle(),
|
||||
message: startupErr.UserMessage(),
|
||||
subtitle: startupErr.Phase(),
|
||||
}, os.Stderr, logger)
|
||||
}
|
||||
|
||||
func reportPrompt(req promptRequest, fallback io.Writer, logger *zap.Logger) {
|
||||
runPromptPipeline(req, buildPromptChannels(defaultCommandRunner{}), fallback, logger)
|
||||
}
|
||||
|
||||
func runPromptPipeline(req promptRequest, channels []promptChannel, fallback io.Writer, logger *zap.Logger) {
|
||||
if logger == nil {
|
||||
logger = dialogLogger()
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
if channel.available != nil {
|
||||
if err := channel.available(); err != nil {
|
||||
logger.Warn("提示通道不可用", zap.String("channel", channel.name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := channel.run(req); err != nil {
|
||||
logger.Warn("提示通道执行失败", zap.String("channel", channel.name), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writePromptFallback(fallback, req.title, req.message)
|
||||
}
|
||||
|
||||
func writePromptFallback(w io.Writer, title, message string) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "错误: "+title+": "+message+"\n"); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
140
backend/cmd/desktop/reporter_test.go
Normal file
140
backend/cmd/desktop/reporter_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
type commandCall struct {
|
||||
timeout time.Duration
|
||||
env []string
|
||||
name string
|
||||
args []string
|
||||
}
|
||||
|
||||
type fakeCommandRunner struct {
|
||||
paths map[string]bool
|
||||
runErrs map[string]error
|
||||
calls []commandCall
|
||||
}
|
||||
|
||||
func (r *fakeCommandRunner) LookPath(file string) (string, error) {
|
||||
if r.paths[file] {
|
||||
return "/usr/bin/" + file, nil
|
||||
}
|
||||
return "", exec.ErrNotFound
|
||||
}
|
||||
|
||||
func (r *fakeCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
|
||||
r.calls = append(r.calls, commandCall{
|
||||
timeout: timeout,
|
||||
env: append([]string(nil), env...),
|
||||
name: name,
|
||||
args: append([]string(nil), args...),
|
||||
})
|
||||
if err := r.runErrs[name]; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRunPromptPipelineFallbackOrder(t *testing.T) {
|
||||
var calls []string
|
||||
channels := []promptChannel{
|
||||
{
|
||||
name: "unavailable",
|
||||
available: func() error {
|
||||
calls = append(calls, "available-1")
|
||||
return errors.New("missing")
|
||||
},
|
||||
run: func(promptRequest) error {
|
||||
calls = append(calls, "run-1")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed",
|
||||
available: func() error {
|
||||
calls = append(calls, "available-2")
|
||||
return nil
|
||||
},
|
||||
run: func(promptRequest) error {
|
||||
calls = append(calls, "run-2")
|
||||
return errors.New("failed")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
available: func() error {
|
||||
calls = append(calls, "available-3")
|
||||
return nil
|
||||
},
|
||||
run: func(promptRequest) error {
|
||||
calls = append(calls, "run-3")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var fallback bytes.Buffer
|
||||
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "启动失败"}, channels, &fallback, zap.NewNop())
|
||||
|
||||
want := []string{"available-1", "available-2", "run-2", "available-3", "run-3"}
|
||||
if fmt.Sprint(calls) != fmt.Sprint(want) {
|
||||
t.Fatalf("调用顺序 = %v, want %v", calls, want)
|
||||
}
|
||||
if fallback.Len() != 0 {
|
||||
t.Fatalf("成功通道后不应写入 fallback,实际: %s", fallback.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunPromptPipelineWritesFallback(t *testing.T) {
|
||||
channels := []promptChannel{
|
||||
{
|
||||
name: "unavailable",
|
||||
available: func() error { return errors.New("missing") },
|
||||
run: func(promptRequest) error { return nil },
|
||||
},
|
||||
}
|
||||
|
||||
var fallback bytes.Buffer
|
||||
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "端口被占用"}, channels, &fallback, zap.NewNop())
|
||||
|
||||
want := "错误: Nex 启动失败: 端口被占用\n"
|
||||
if fallback.String() != want {
|
||||
t.Fatalf("fallback = %q, want %q", fallback.String(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportStartupFailureLogsRedactedError(t *testing.T) {
|
||||
old := buildPromptChannels
|
||||
buildPromptChannels = func(commandRunner) []promptChannel {
|
||||
return []promptChannel{{name: "fake-success", run: func(promptRequest) error { return nil }}}
|
||||
}
|
||||
t.Cleanup(func() { buildPromptChannels = old })
|
||||
|
||||
core, logs := observer.New(zap.ErrorLevel)
|
||||
logger := zap.New(core)
|
||||
err := errors.New("数据库连接失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test")
|
||||
|
||||
reportStartupFailure(err, logger)
|
||||
|
||||
entries := logs.All()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("应记录 1 条错误日志,实际: %d", len(entries))
|
||||
}
|
||||
fields := fmt.Sprint(entries[0].ContextMap())
|
||||
for _, secret := range []string{"secret", "sk-test"} {
|
||||
if strings.Contains(fields, secret) {
|
||||
t.Fatalf("启动失败日志不应包含敏感信息 %q,实际: %s", secret, fields)
|
||||
}
|
||||
}
|
||||
}
|
||||
332
backend/cmd/desktop/run_desktop_test.go
Normal file
332
backend/cmd/desktop/run_desktop_test.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/database"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type fakeDesktopLock struct {
|
||||
lockErr error
|
||||
unlockCount atomic.Int32
|
||||
}
|
||||
|
||||
func (l *fakeDesktopLock) Lock() error {
|
||||
return l.lockErr
|
||||
}
|
||||
|
||||
func (l *fakeDesktopLock) Unlock() error {
|
||||
l.unlockCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *fakeDesktopLock) unlocked() bool {
|
||||
return l.unlockCount.Load() > 0
|
||||
}
|
||||
|
||||
type recordingListener struct {
|
||||
net.Listener
|
||||
closeCount atomic.Int32
|
||||
}
|
||||
|
||||
func newRecordingListener(t *testing.T) *recordingListener {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试 listener 失败: %v", err)
|
||||
}
|
||||
return &recordingListener{Listener: listener}
|
||||
}
|
||||
|
||||
func (l *recordingListener) Close() error {
|
||||
l.closeCount.Add(1)
|
||||
return l.Listener.Close()
|
||||
}
|
||||
|
||||
func (l *recordingListener) closed() bool {
|
||||
return l.closeCount.Load() > 0
|
||||
}
|
||||
|
||||
func testDesktopConfig(t *testing.T) *config.Config {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Server.Port = 0
|
||||
cfg.Database.Driver = "sqlite"
|
||||
cfg.Database.Path = filepath.Join(tmpDir, "config.db")
|
||||
cfg.Log.Path = filepath.Join(tmpDir, "log")
|
||||
return cfg
|
||||
}
|
||||
|
||||
func installDesktopTestHooks(t *testing.T, cfg *config.Config, mutate func(*desktopRuntimeHooks)) {
|
||||
t.Helper()
|
||||
|
||||
oldHooks := desktopHooks
|
||||
oldServer := server
|
||||
oldLogger := zapLogger
|
||||
oldShutdownCtx := shutdownCtx
|
||||
oldShutdownCancel := shutdownCancel
|
||||
|
||||
server = nil
|
||||
zapLogger = nil
|
||||
shutdownCtx = nil
|
||||
shutdownCancel = nil
|
||||
|
||||
hooks := defaultDesktopRuntimeHooks()
|
||||
if cfg != nil {
|
||||
hooks.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
|
||||
return cfg, config.ConfigMetadata{ConfigPath: filepath.Join(t.TempDir(), "config.yaml")}, nil
|
||||
}
|
||||
}
|
||||
hooks.upgradeLogger = func(_ *zap.Logger, _ pkgLogger.Config) (*zap.Logger, error) {
|
||||
return zap.NewNop(), nil
|
||||
}
|
||||
hooks.setupStaticFiles = func(*gin.Engine) error { return nil }
|
||||
hooks.startServer = func(*http.Server, net.Listener, chan<- error, *zap.Logger) {}
|
||||
hooks.setupSystray = func(int, <-chan error) error { return nil }
|
||||
|
||||
if mutate != nil {
|
||||
mutate(&hooks)
|
||||
}
|
||||
desktopHooks = hooks
|
||||
|
||||
t.Cleanup(func() {
|
||||
if server != nil {
|
||||
_ = server.Close()
|
||||
}
|
||||
desktopHooks = oldHooks
|
||||
server = oldServer
|
||||
zapLogger = oldLogger
|
||||
shutdownCtx = oldShutdownCtx
|
||||
shutdownCancel = oldShutdownCancel
|
||||
})
|
||||
}
|
||||
|
||||
func requireStartupPhase(t *testing.T, err error, want startupPhase) {
|
||||
t.Helper()
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("期望 %s 阶段启动错误,实际 nil", want)
|
||||
}
|
||||
var startupErr *startupError
|
||||
if !errors.As(err, &startupErr) {
|
||||
t.Fatalf("期望 startupError,实际: %T %v", err, err)
|
||||
}
|
||||
if startupErr.phase != want {
|
||||
t.Fatalf("phase = %s, want %s", startupErr.phase, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopConfigFailureReturnsConfigPhase(t *testing.T) {
|
||||
installDesktopTestHooks(t, nil, func(h *desktopRuntimeHooks) {
|
||||
h.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
|
||||
return nil, config.ConfigMetadata{}, errors.New("yaml 解析失败")
|
||||
}
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, phaseConfig)
|
||||
}
|
||||
|
||||
func TestRunDesktopSingletonFailurePrecedesPortListen(t *testing.T) {
|
||||
cfg := testDesktopConfig(t)
|
||||
lock := &fakeDesktopLock{lockErr: errors.New("已有实例运行")}
|
||||
listenCalled := false
|
||||
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||
h.newLock = func(string) singletonLocker { return lock }
|
||||
h.listen = func(int) (net.Listener, error) {
|
||||
listenCalled = true
|
||||
return nil, errors.New("不应监听端口")
|
||||
}
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, phaseSingleton)
|
||||
if listenCalled {
|
||||
t.Fatal("单实例锁失败时不应继续监听端口")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopPortFailureUnlocksSingleton(t *testing.T) {
|
||||
cfg := testDesktopConfig(t)
|
||||
lock := &fakeDesktopLock{}
|
||||
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||
h.newLock = func(string) singletonLocker { return lock }
|
||||
h.listen = func(int) (net.Listener, error) { return nil, errors.New("bind failed") }
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, phasePort)
|
||||
if !lock.unlocked() {
|
||||
t.Fatal("端口监听失败时应释放单实例锁")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopLoggerFailureClosesListenerAndUnlocks(t *testing.T) {
|
||||
cfg := testDesktopConfig(t)
|
||||
lock := &fakeDesktopLock{}
|
||||
listener := newRecordingListener(t)
|
||||
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||
h.newLock = func(string) singletonLocker { return lock }
|
||||
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||
h.upgradeLogger = func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error) {
|
||||
return nil, errors.New("log permission denied")
|
||||
}
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, phaseLogger)
|
||||
if !listener.closed() {
|
||||
t.Fatal("日志初始化失败时应关闭 listener")
|
||||
}
|
||||
if !lock.unlocked() {
|
||||
t.Fatal("日志初始化失败时应释放单实例锁")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopDatabaseFailureClassification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want startupPhase
|
||||
}{
|
||||
{name: "database", err: errors.New("open failed"), want: phaseDatabase},
|
||||
{name: "migration", err: fmt.Errorf("%w: %w", database.ErrMigration, errors.New("goose failed")), want: phaseMigration},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := testDesktopConfig(t)
|
||||
lock := &fakeDesktopLock{}
|
||||
listener := newRecordingListener(t)
|
||||
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||
h.newLock = func(string) singletonLocker { return lock }
|
||||
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||
h.initDB = func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error) { return nil, tt.err }
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, tt.want)
|
||||
if !listener.closed() {
|
||||
t.Fatal("数据库失败时应关闭 listener")
|
||||
}
|
||||
if !lock.unlocked() {
|
||||
t.Fatal("数据库失败时应释放单实例锁")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopInternalStartupFailurePhasesAndDatabaseCleanup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(*desktopRuntimeHooks)
|
||||
want startupPhase
|
||||
}{
|
||||
{
|
||||
name: "adapter",
|
||||
mutate: func(h *desktopRuntimeHooks) {
|
||||
h.registerAdapters = func(conversion.AdapterRegistry) error { return errors.New("adapter failed") }
|
||||
},
|
||||
want: phaseAdapter,
|
||||
},
|
||||
{
|
||||
name: "static",
|
||||
mutate: func(h *desktopRuntimeHooks) {
|
||||
h.setupStaticFiles = func(*gin.Engine) error { return errors.New("missing frontend") }
|
||||
},
|
||||
want: phaseStaticResource,
|
||||
},
|
||||
{
|
||||
name: "server",
|
||||
mutate: func(h *desktopRuntimeHooks) {
|
||||
h.startServer = func(_ *http.Server, _ net.Listener, errCh chan<- error, _ *zap.Logger) {
|
||||
errCh <- errors.New("serve failed")
|
||||
}
|
||||
},
|
||||
want: phaseServer,
|
||||
},
|
||||
{
|
||||
name: "tray",
|
||||
mutate: func(h *desktopRuntimeHooks) {
|
||||
h.setupSystray = func(int, <-chan error) error {
|
||||
return newStartupError(phaseTray, "托盘初始化失败", errors.New("tray failed"))
|
||||
}
|
||||
},
|
||||
want: phaseTray,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := testDesktopConfig(t)
|
||||
lock := &fakeDesktopLock{}
|
||||
listener := newRecordingListener(t)
|
||||
closeDBCalled := false
|
||||
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
|
||||
h.newLock = func(string) singletonLocker { return lock }
|
||||
h.listen = func(int) (net.Listener, error) { return listener, nil }
|
||||
h.closeDB = func(db *gorm.DB) {
|
||||
closeDBCalled = true
|
||||
database.Close(db)
|
||||
}
|
||||
tt.mutate(h)
|
||||
})
|
||||
|
||||
err := runDesktop(zap.NewNop())
|
||||
requireStartupPhase(t, err, tt.want)
|
||||
if !closeDBCalled {
|
||||
t.Fatal("数据库初始化后的启动失败应关闭数据库")
|
||||
}
|
||||
if !listener.closed() {
|
||||
t.Fatal("数据库初始化后的启动失败应关闭 listener")
|
||||
}
|
||||
if !lock.unlocked() {
|
||||
t.Fatal("数据库初始化后的启动失败应释放单实例锁")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDesktopBrowserFailureRemainsNonFatal(t *testing.T) {
|
||||
controller := newFakeTrayController()
|
||||
notified := make(chan string, 1)
|
||||
controller.run = func(onReady func(), _ func()) {
|
||||
onReady()
|
||||
<-controller.quitCh
|
||||
}
|
||||
|
||||
err := runSystray(19826, trayOptions{
|
||||
controller: controller,
|
||||
readyTimeout: time.Second,
|
||||
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||
openBrowser: func(string) error { return errors.New("no browser") },
|
||||
notify: func(_, message string) {
|
||||
notified <- message
|
||||
controller.Quit()
|
||||
},
|
||||
logger: zap.NewNop(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("浏览器打开失败不应导致 runSystray 返回 fatal: %v", err)
|
||||
}
|
||||
if got := <-notified; got == "" {
|
||||
t.Fatal("浏览器打开失败应提示用户手动访问")
|
||||
}
|
||||
}
|
||||
96
backend/cmd/desktop/startup_error.go
Normal file
96
backend/cmd/desktop/startup_error.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type startupPhase string
|
||||
|
||||
const (
|
||||
phaseConfig startupPhase = "config"
|
||||
phaseSingleton startupPhase = "singleton"
|
||||
phasePort startupPhase = "port"
|
||||
phaseLogger startupPhase = "logger"
|
||||
phaseDatabase startupPhase = "database"
|
||||
phaseMigration startupPhase = "migration"
|
||||
phaseAdapter startupPhase = "adapter"
|
||||
phaseStaticResource startupPhase = "static"
|
||||
phaseServer startupPhase = "server"
|
||||
phaseTray startupPhase = "tray"
|
||||
)
|
||||
|
||||
type startupError struct {
|
||||
phase startupPhase
|
||||
message string
|
||||
cause error
|
||||
}
|
||||
|
||||
func newStartupError(phase startupPhase, message string, cause error) *startupError {
|
||||
return &startupError{
|
||||
phase: phase,
|
||||
message: redactSensitive(message),
|
||||
cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *startupError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("%s: %s", e.phase, e.message)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s: %s", e.phase, e.message, redactSensitive(e.cause.Error()))
|
||||
}
|
||||
|
||||
func (e *startupError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func (e *startupError) Phase() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return string(e.phase)
|
||||
}
|
||||
|
||||
func (e *startupError) UserMessage() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return redactSensitive(e.message)
|
||||
}
|
||||
|
||||
var sensitiveReplacers = []struct {
|
||||
pattern *regexp.Regexp
|
||||
replacement string
|
||||
}{
|
||||
{regexp.MustCompile(`(?i)(password\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||
{regexp.MustCompile(`(?i)(api[_-]?key\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||
{regexp.MustCompile(`(?i)(secret\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
|
||||
{regexp.MustCompile(`([^\s:/]+):([^\s@]+)@tcp\(`), `${1}:<redacted>@tcp(`},
|
||||
{regexp.MustCompile(`(://[^\s:/]+):([^\s@]+)@`), `${1}:<redacted>@`},
|
||||
}
|
||||
|
||||
func redactSensitive(s string) string {
|
||||
for _, replacer := range sensitiveReplacers {
|
||||
s = replacer.pattern.ReplaceAllString(s, replacer.replacement)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func startupTitle() string {
|
||||
return appName + " 启动失败"
|
||||
}
|
||||
|
||||
func startupServerErrorMessage() string {
|
||||
return "后端服务启动失败\n\n请检查端口占用、网络权限或查看日志获取更多信息"
|
||||
}
|
||||
|
||||
func startupInternalErrorMessage() string {
|
||||
return "应用初始化失败\n\n请查看日志或重新安装应用"
|
||||
}
|
||||
40
backend/cmd/desktop/startup_error_test.go
Normal file
40
backend/cmd/desktop/startup_error_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStartupErrorContainsPhaseAndCause(t *testing.T) {
|
||||
cause := errors.New("底层失败")
|
||||
err := newStartupError(phaseDatabase, "数据库初始化失败", cause)
|
||||
|
||||
if err.Phase() != "database" {
|
||||
t.Fatalf("phase = %q, want database", err.Phase())
|
||||
}
|
||||
if !errors.Is(err, cause) {
|
||||
t.Fatal("startupError 应保留底层 cause")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "database") {
|
||||
t.Fatalf("错误字符串应包含 phase,实际: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupErrorRedactsSensitiveUserMessage(t *testing.T) {
|
||||
message := "数据库初始化失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test"
|
||||
err := newStartupError(phaseDatabase, message, errors.New("cause password=secret api_key=sk-test"))
|
||||
userMessage := err.UserMessage()
|
||||
|
||||
for _, secret := range []string{"secret", "sk-test"} {
|
||||
if strings.Contains(userMessage, secret) {
|
||||
t.Fatalf("用户提示不应包含敏感信息 %q,实际: %s", secret, userMessage)
|
||||
}
|
||||
if strings.Contains(err.Error(), secret) {
|
||||
t.Fatalf("日志错误字符串不应包含敏感信息 %q,实际: %s", secret, err.Error())
|
||||
}
|
||||
}
|
||||
if !strings.Contains(userMessage, "<redacted>") {
|
||||
t.Fatalf("用户提示应包含脱敏占位符,实际: %s", userMessage)
|
||||
}
|
||||
}
|
||||
@@ -13,14 +13,15 @@ import (
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
"icon.png": {Data: []byte("png")},
|
||||
"assets/test.js": {Data: []byte("console.log('test')")},
|
||||
"assets/test.css": {Data: []byte("body {}")},
|
||||
"assets/test.svg": {Data: []byte("<svg></svg>")},
|
||||
"assets/test.woff": {Data: []byte("font")},
|
||||
})
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
@@ -73,13 +74,12 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -88,13 +88,12 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -128,12 +127,6 @@ func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
|
||||
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var gotProtocol string
|
||||
@@ -148,7 +141,10 @@ func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
"assets/test.js": {Data: []byte("console.log('test')")},
|
||||
})
|
||||
|
||||
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
@@ -199,14 +195,11 @@ func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
if gotProtocol != "" || gotPath != "" {
|
||||
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
return
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望静态资源返回 200, 实际 %d", w.Code)
|
||||
}
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
231
backend/cmd/desktop/tray.go
Normal file
231
backend/cmd/desktop/tray.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const defaultTrayReadyTimeout = 5 * time.Second
|
||||
|
||||
type trayMenuItem interface {
|
||||
Disable()
|
||||
Clicked() <-chan struct{}
|
||||
}
|
||||
|
||||
type trayController interface {
|
||||
Run(onReady func(), onExit func())
|
||||
Quit()
|
||||
SetIcon(icon []byte)
|
||||
SetTooltip(tooltip string)
|
||||
AddMenuItem(title, tooltip string) trayMenuItem
|
||||
AddSeparator()
|
||||
}
|
||||
|
||||
type realTrayController struct{}
|
||||
|
||||
func (realTrayController) Run(onReady func(), onExit func()) {
|
||||
systray.Run(onReady, onExit)
|
||||
}
|
||||
|
||||
func (realTrayController) Quit() {
|
||||
systray.Quit()
|
||||
}
|
||||
|
||||
func (realTrayController) SetIcon(icon []byte) {
|
||||
systray.SetIcon(icon)
|
||||
}
|
||||
|
||||
func (realTrayController) SetTooltip(tooltip string) {
|
||||
systray.SetTooltip(tooltip)
|
||||
}
|
||||
|
||||
func (realTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
|
||||
return realTrayMenuItem{item: systray.AddMenuItem(title, tooltip)}
|
||||
}
|
||||
|
||||
func (realTrayController) AddSeparator() {
|
||||
systray.AddSeparator()
|
||||
}
|
||||
|
||||
type realTrayMenuItem struct {
|
||||
item *systray.MenuItem
|
||||
}
|
||||
|
||||
func (m realTrayMenuItem) Disable() {
|
||||
m.item.Disable()
|
||||
}
|
||||
|
||||
func (m realTrayMenuItem) Clicked() <-chan struct{} {
|
||||
return m.item.ClickedCh
|
||||
}
|
||||
|
||||
type trayOptions struct {
|
||||
controller trayController
|
||||
readyTimeout time.Duration
|
||||
iconLoader func() ([]byte, error)
|
||||
openBrowser func(string) error
|
||||
notify func(string, string)
|
||||
logger *zap.Logger
|
||||
fatalErrCh <-chan error
|
||||
}
|
||||
|
||||
func setupSystray(port int, fatalErrCh <-chan error) error {
|
||||
return runSystray(port, trayOptions{
|
||||
controller: realTrayController{},
|
||||
readyTimeout: defaultTrayReadyTimeout,
|
||||
iconLoader: loadTrayIcon,
|
||||
openBrowser: openBrowser,
|
||||
notify: showError,
|
||||
logger: dialogLogger(),
|
||||
fatalErrCh: fatalErrCh,
|
||||
})
|
||||
}
|
||||
|
||||
func runSystray(port int, opts trayOptions) error {
|
||||
if opts.controller == nil {
|
||||
opts.controller = realTrayController{}
|
||||
}
|
||||
if opts.readyTimeout <= 0 {
|
||||
opts.readyTimeout = defaultTrayReadyTimeout
|
||||
}
|
||||
if opts.iconLoader == nil {
|
||||
opts.iconLoader = loadTrayIcon
|
||||
}
|
||||
if opts.openBrowser == nil {
|
||||
opts.openBrowser = openBrowser
|
||||
}
|
||||
if opts.notify == nil {
|
||||
opts.notify = showError
|
||||
}
|
||||
if opts.logger == nil {
|
||||
opts.logger = dialogLogger()
|
||||
}
|
||||
|
||||
readyCh := make(chan struct{})
|
||||
doneCh := make(chan struct{})
|
||||
errCh := make(chan error, 1)
|
||||
var readyOnce sync.Once
|
||||
var errOnce sync.Once
|
||||
|
||||
signalReady := func() {
|
||||
readyOnce.Do(func() { close(readyCh) })
|
||||
}
|
||||
signalError := func(err error) {
|
||||
errOnce.Do(func() { errCh <- err })
|
||||
}
|
||||
|
||||
go monitorTrayStartup(port, opts, readyCh, doneCh, signalError)
|
||||
|
||||
opts.controller.Run(func() {
|
||||
handleTrayReady(port, opts, signalReady, signalError)
|
||||
}, nil)
|
||||
close(doneCh)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func monitorTrayStartup(port int, opts trayOptions, readyCh <-chan struct{}, doneCh <-chan struct{}, signalError func(error)) {
|
||||
timer := time.NewTimer(opts.readyTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
ready := false
|
||||
for {
|
||||
select {
|
||||
case <-readyCh:
|
||||
ready = true
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
openDesktopBrowser(port, opts)
|
||||
readyCh = nil
|
||||
case <-timer.C:
|
||||
if !ready {
|
||||
signalError(newStartupError(phaseTray, "托盘初始化超时", fmt.Errorf("托盘未在 %s 内 ready", opts.readyTimeout)))
|
||||
opts.controller.Quit()
|
||||
}
|
||||
case err := <-opts.fatalErrCh:
|
||||
if err != nil {
|
||||
signalError(newStartupError(phaseServer, startupServerErrorMessage(), err))
|
||||
opts.controller.Quit()
|
||||
}
|
||||
case <-doneCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleTrayReady(port int, opts trayOptions, signalReady func(), signalError func(error)) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
err := fmt.Errorf("托盘初始化 panic: %v", recovered)
|
||||
signalError(newStartupError(phaseTray, "托盘菜单初始化失败", err))
|
||||
opts.controller.Quit()
|
||||
}
|
||||
}()
|
||||
|
||||
icon, err := opts.iconLoader()
|
||||
if err != nil {
|
||||
signalError(newStartupError(phaseTray, "托盘图标资源无法加载", err))
|
||||
opts.controller.Quit()
|
||||
return
|
||||
}
|
||||
|
||||
opts.controller.SetIcon(icon)
|
||||
opts.controller.SetTooltip(appTooltip)
|
||||
|
||||
mOpen := opts.controller.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
opts.controller.AddSeparator()
|
||||
mStatus := opts.controller.AddMenuItem("状态: 运行中", "")
|
||||
mStatus.Disable()
|
||||
mPort := opts.controller.AddMenuItem(desktopPortMenuTitle(port), "")
|
||||
mPort.Disable()
|
||||
opts.controller.AddSeparator()
|
||||
mQuit := opts.controller.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.Clicked():
|
||||
if err := opts.openBrowser(desktopURL(port)); err != nil {
|
||||
opts.logger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mQuit.Clicked():
|
||||
doShutdown()
|
||||
opts.controller.Quit()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
signalReady()
|
||||
}
|
||||
|
||||
func openDesktopBrowser(port int, opts trayOptions) {
|
||||
url := desktopURL(port)
|
||||
if err := opts.openBrowser(url); err != nil {
|
||||
opts.logger.Warn("无法打开浏览器", zap.Error(err))
|
||||
opts.notify(appName, fmt.Sprintf("无法自动打开浏览器,请手动访问 %s", url))
|
||||
}
|
||||
}
|
||||
|
||||
func loadTrayIcon() ([]byte, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return embedfs.Assets.ReadFile("assets/icon.ico")
|
||||
}
|
||||
return embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
169
backend/cmd/desktop/tray_test.go
Normal file
169
backend/cmd/desktop/tray_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type fakeTrayController struct {
|
||||
run func(onReady func(), onExit func())
|
||||
|
||||
quitCh chan struct{}
|
||||
quitOnce sync.Once
|
||||
|
||||
icon []byte
|
||||
tooltip string
|
||||
menuItems []*fakeTrayMenuItem
|
||||
}
|
||||
|
||||
func newFakeTrayController() *fakeTrayController {
|
||||
return &fakeTrayController{quitCh: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) Run(onReady func(), onExit func()) {
|
||||
if c.run != nil {
|
||||
c.run(onReady, onExit)
|
||||
return
|
||||
}
|
||||
onReady()
|
||||
<-c.quitCh
|
||||
if onExit != nil {
|
||||
onExit()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) Quit() {
|
||||
c.quitOnce.Do(func() { close(c.quitCh) })
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) SetIcon(icon []byte) {
|
||||
c.icon = append([]byte(nil), icon...)
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) SetTooltip(tooltip string) {
|
||||
c.tooltip = tooltip
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
|
||||
item := &fakeTrayMenuItem{clicked: make(chan struct{}), title: title, tooltip: tooltip}
|
||||
c.menuItems = append(c.menuItems, item)
|
||||
return item
|
||||
}
|
||||
|
||||
func (c *fakeTrayController) AddSeparator() {}
|
||||
|
||||
type fakeTrayMenuItem struct {
|
||||
clicked chan struct{}
|
||||
title string
|
||||
tooltip string
|
||||
disabled bool
|
||||
}
|
||||
|
||||
func (m *fakeTrayMenuItem) Disable() {
|
||||
m.disabled = true
|
||||
}
|
||||
|
||||
func (m *fakeTrayMenuItem) Clicked() <-chan struct{} {
|
||||
return m.clicked
|
||||
}
|
||||
|
||||
func TestRunSystrayReadyOpensBrowser(t *testing.T) {
|
||||
controller := newFakeTrayController()
|
||||
opened := make(chan string, 1)
|
||||
|
||||
err := runSystray(19826, trayOptions{
|
||||
controller: controller,
|
||||
readyTimeout: time.Second,
|
||||
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||
openBrowser: func(url string) error {
|
||||
opened <- url
|
||||
controller.Quit()
|
||||
return nil
|
||||
},
|
||||
notify: func(string, string) {},
|
||||
logger: zap.NewNop(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("托盘 ready 成功不应返回错误: %v", err)
|
||||
}
|
||||
|
||||
if got := <-opened; got != "http://localhost:19826" {
|
||||
t.Fatalf("浏览器 URL = %s", got)
|
||||
}
|
||||
if string(controller.icon) != "icon" {
|
||||
t.Fatalf("应设置托盘图标")
|
||||
}
|
||||
if controller.tooltip != appTooltip {
|
||||
t.Fatalf("tooltip = %q, want %q", controller.tooltip, appTooltip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSystrayReadyTimeoutReturnsTrayStartupError(t *testing.T) {
|
||||
controller := newFakeTrayController()
|
||||
controller.run = func(_ func(), _ func()) {
|
||||
<-controller.quitCh
|
||||
}
|
||||
|
||||
err := runSystray(19826, trayOptions{
|
||||
controller: controller,
|
||||
readyTimeout: 10 * time.Millisecond,
|
||||
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||
openBrowser: func(string) error { return nil },
|
||||
notify: func(string, string) {},
|
||||
logger: zap.NewNop(),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("托盘 ready timeout 应返回错误")
|
||||
}
|
||||
var startupErr *startupError
|
||||
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
|
||||
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSystrayIconLoadFailureReturnsTrayStartupError(t *testing.T) {
|
||||
controller := newFakeTrayController()
|
||||
|
||||
err := runSystray(19826, trayOptions{
|
||||
controller: controller,
|
||||
readyTimeout: time.Second,
|
||||
iconLoader: func() ([]byte, error) { return nil, errors.New("missing icon") },
|
||||
openBrowser: func(string) error { return nil },
|
||||
notify: func(string, string) {},
|
||||
logger: zap.NewNop(),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("托盘图标加载失败应返回错误")
|
||||
}
|
||||
var startupErr *startupError
|
||||
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
|
||||
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSystrayBrowserOpenFailureIsNonFatal(t *testing.T) {
|
||||
controller := newFakeTrayController()
|
||||
notified := make(chan string, 1)
|
||||
|
||||
err := runSystray(19826, trayOptions{
|
||||
controller: controller,
|
||||
readyTimeout: time.Second,
|
||||
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
|
||||
openBrowser: func(string) error { return errors.New("no browser") },
|
||||
notify: func(_, message string) {
|
||||
notified <- message
|
||||
controller.Quit()
|
||||
},
|
||||
logger: zap.NewNop(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("浏览器打开失败不应成为 fatal: %v", err)
|
||||
}
|
||||
if got := <-notified; got == "" {
|
||||
t.Fatal("浏览器打开失败应提示用户")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -17,6 +18,8 @@ import (
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
var ErrMigration = errors.New("数据库迁移失败")
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
@@ -26,7 +29,7 @@ func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
return nil, fmt.Errorf("%w: %w", ErrMigration, err)
|
||||
}
|
||||
|
||||
configurePool(db, cfg, moduleLogger)
|
||||
|
||||
Reference in New Issue
Block a user