1
0

feat: 增强桌面启动失败提示与测试覆盖

This commit is contained in:
2026-05-08 23:42:48 +08:00
parent c524e8f928
commit 2dec9e5c54
21 changed files with 1857 additions and 297 deletions

View File

@@ -4,17 +4,35 @@ package main
import (
"fmt"
"os/exec"
"strings"
"go.uber.org/zap"
)
func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title))
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
func platformStartupChannels(runner commandRunner) []promptChannel {
return []promptChannel{
{
name: "macos-notification",
available: func() error {
_, err := runner.LookPath("osascript")
return err
},
run: func(req promptRequest) error {
script := fmt.Sprintf(`display notification "%s" with title "%s" subtitle "%s"`,
escapeAppleScript(req.message), escapeAppleScript(req.title), escapeAppleScript(req.subtitle))
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
},
},
{
name: "macos-alert",
available: func() error {
_, err := runner.LookPath("osascript")
return err
},
run: func(req promptRequest) error {
script := fmt.Sprintf(`display alert "%s" message "%s" as critical buttons {"OK"} default button "OK"`,
escapeAppleScript(req.title), escapeAppleScript(req.message))
return runner.Run(promptCommandTimeout, nil, "osascript", "-e", script)
},
},
}
}

View File

@@ -0,0 +1,46 @@
//go:build darwin
package main
import (
"strings"
"testing"
)
func TestDarwinStartupChannelsBuildNotificationAndAlert(t *testing.T) {
runner := &fakeCommandRunner{paths: map[string]bool{"osascript": true}}
channels := platformStartupChannels(runner)
if len(channels) != 2 {
t.Fatalf("macOS 应有 notification 和 alert 两级通道,实际: %d", len(channels))
}
req := promptRequest{title: "Nex 启动失败", subtitle: "config", message: "路径 C:\\tmp 包含 \"quote\""}
for _, channel := range channels {
if err := channel.available(); err != nil {
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
}
if err := channel.run(req); err != nil {
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
}
}
if len(runner.calls) != 2 {
t.Fatalf("应执行两次 osascript实际: %d", len(runner.calls))
}
if runner.calls[0].name != "osascript" || runner.calls[0].args[0] != "-e" {
t.Fatalf("notification 命令参数错误: %#v", runner.calls[0])
}
if script := runner.calls[0].args[1]; !strings.Contains(script, "display notification") || !strings.Contains(script, `\\tmp`) || !strings.Contains(script, `\"quote\"`) {
t.Fatalf("notification AppleScript 未正确构造或转义: %s", script)
}
if script := runner.calls[1].args[1]; !strings.Contains(script, "display alert") || !strings.Contains(script, "as critical") {
t.Fatalf("alert AppleScript 未使用 critical 告警: %s", script)
}
}
func TestEscapeAppleScript(t *testing.T) {
got := escapeAppleScript(`C:\tmp "quote"`)
if !strings.Contains(got, `C:\\tmp`) || !strings.Contains(got, `\"quote\"`) {
t.Fatalf("AppleScript 转义结果错误: %s", got)
}
}

View File

@@ -3,8 +3,9 @@
package main
import (
"errors"
"fmt"
"os/exec"
"os"
"sync"
)
@@ -12,56 +13,99 @@ type dialogToolType int
const (
toolNone dialogToolType = iota
toolZenity
toolKdialog
toolNotifySend
toolKdialogPassive
toolZenity
toolKdialogError
toolXmessage
)
var (
dialogTool dialogToolType
dialogToolOnce sync.Once
dialogTools map[string]bool
dialogToolOnce sync.Once
dialogToolNames = []string{"notify-send", "kdialog", "zenity", "xmessage"}
)
func init() {
dialogToolOnce.Do(detectDialogTool)
dialogToolOnce.Do(func() { detectDialogTools(defaultCommandRunner{}) })
}
func detectDialogTool() {
tools := []struct {
name string
typ dialogToolType
}{
{"zenity", toolZenity},
{"kdialog", toolKdialog},
{"notify-send", toolNotifySend},
{"xmessage", toolXmessage},
func platformStartupChannels(runner commandRunner) []promptChannel {
return []promptChannel{
linuxCommandChannel("notify-send", toolNotifySend, runner, linuxHasGraphicalSessionAndDBus, func(req promptRequest) []string {
return []string{"-u", "critical", "-a", appName, "-i", "nex", req.title, req.message}
}),
linuxCommandChannel("kdialog", toolKdialogPassive, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
return []string{"--title", req.title, "--passivepopup", req.message, "10"}
}),
linuxCommandChannel("zenity", toolZenity, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
return []string{"--error", fmt.Sprintf("--title=%s", req.title), fmt.Sprintf("--text=%s", req.message)}
}),
linuxCommandChannel("kdialog", toolKdialogError, runner, linuxHasGraphicalSession, func(req promptRequest) []string {
return []string{"--title", req.title, "--error", req.message}
}),
linuxCommandChannel("xmessage", toolXmessage, runner, linuxHasX11Display, func(req promptRequest) []string {
return []string{"-center", "-buttons", "OK:0", "-default", "OK", fmt.Sprintf("%s: %s", req.title, req.message)}
}),
}
}
for _, tool := range tools {
if _, err := exec.LookPath(tool.name); err == nil {
dialogTool = tool.typ
return
func detectDialogTools(runner commandRunner) {
dialogTools = make(map[string]bool, len(dialogToolNames))
for _, name := range dialogToolNames {
_, err := runner.LookPath(name)
dialogTools[name] = err == nil
}
}
func linuxCommandChannel(name string, typ dialogToolType, runner commandRunner, environmentOK func() error, args func(promptRequest) []string) promptChannel {
return promptChannel{
name: fmt.Sprintf("linux-%s-%d", name, typ),
available: func() error {
if err := linuxCommandAvailable(runner, name); err != nil {
return err
}
return environmentOK()
},
run: func(req promptRequest) error {
return runner.Run(promptCommandTimeout, nil, name, args(req)...)
},
}
}
func linuxCommandAvailable(runner commandRunner, name string) error {
if _, ok := runner.(defaultCommandRunner); ok {
dialogToolOnce.Do(func() { detectDialogTools(runner) })
if dialogTools[name] {
return nil
}
return fmt.Errorf("%s 不可用", name)
}
dialogTool = toolNone
_, err := runner.LookPath(name)
return err
}
func showError(title, message string) {
switch dialogTool {
case toolZenity:
exec.Command("zenity", "--error",
fmt.Sprintf("--title=%s", title),
fmt.Sprintf("--text=%s", message)).Run()
case toolKdialog:
exec.Command("kdialog", "--error", message, "--title", title).Run()
case toolNotifySend:
exec.Command("notify-send", "-u", "critical", title, message).Run()
case toolXmessage:
exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run()
default:
dialogLogger().Error("无法显示错误对话框")
func linuxHasGraphicalSession() error {
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
return errors.New("缺少图形会话")
}
return nil
}
func linuxHasGraphicalSessionAndDBus() error {
if err := linuxHasGraphicalSession(); err != nil {
return err
}
if os.Getenv("DBUS_SESSION_BUS_ADDRESS") == "" {
return errors.New("缺少 DBus session bus")
}
return nil
}
func linuxHasX11Display() error {
if os.Getenv("DISPLAY") == "" {
return errors.New("缺少 X11 DISPLAY")
}
return nil
}

View File

@@ -0,0 +1,61 @@
//go:build linux
package main
import "testing"
func TestLinuxStartupChannelsPriorityAndArguments(t *testing.T) {
t.Setenv("DISPLAY", ":0")
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "unix:path=/tmp/dbus")
runner := &fakeCommandRunner{paths: map[string]bool{
"notify-send": true,
"kdialog": true,
"zenity": true,
"xmessage": true,
}}
channels := platformStartupChannels(runner)
if len(channels) != 5 {
t.Fatalf("Linux 应有 5 个 UI 通道,实际: %d", len(channels))
}
req := promptRequest{title: "Nex 启动失败", message: "端口被占用"}
for _, channel := range channels {
if err := channel.available(); err != nil {
t.Fatalf("通道 %s 应可用: %v", channel.name, err)
}
if err := channel.run(req); err != nil {
t.Fatalf("通道 %s 执行失败: %v", channel.name, err)
}
}
wantNames := []string{"notify-send", "kdialog", "zenity", "kdialog", "xmessage"}
for i, want := range wantNames {
if got := runner.calls[i].name; got != want {
t.Fatalf("第 %d 个命令 = %s, want %s", i, got, want)
}
}
if got := runner.calls[0].args; len(got) < 2 || got[0] != "-u" || got[1] != "critical" {
t.Fatalf("notify-send 应使用 critical 参数,实际: %#v", got)
}
if got := runner.calls[1].args; len(got) < 3 || got[2] != "--passivepopup" {
t.Fatalf("kdialog 第一跳应使用 passivepopup实际: %#v", got)
}
if got := runner.calls[2].args; len(got) < 1 || got[0] != "--error" {
t.Fatalf("zenity 应使用 --error实际: %#v", got)
}
if got := runner.calls[4].args; len(got) < 1 || got[0] != "-center" {
t.Fatalf("xmessage 应居中显示,实际: %#v", got)
}
}
func TestLinuxNotifySendRequiresDBus(t *testing.T) {
t.Setenv("DISPLAY", ":0")
t.Setenv("DBUS_SESSION_BUS_ADDRESS", "")
runner := &fakeCommandRunner{paths: map[string]bool{"notify-send": true}}
channels := platformStartupChannels(runner)
if err := channels[0].available(); err == nil {
t.Fatal("notify-send 缺少 DBus session bus 时应不可用")
}
}

View File

@@ -3,17 +3,21 @@
package main
import (
"encoding/base64"
"errors"
"fmt"
"syscall"
"unicode/utf16"
"unsafe"
"go.uber.org/zap"
)
const (
mbOK = 0x00000000
mbIconError = 0x10
mbIconInformation = 0x40
mbTaskModal = 0x00002000
mbSetForeground = 0x00010000
mbTopMost = 0x00040000
)
var (
@@ -25,12 +29,79 @@ var (
}
)
func showError(title, message string) {
if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
func platformStartupChannels(runner commandRunner) []promptChannel {
return []promptChannel{
{
name: "windows-toast",
available: func() error {
_, err := findPowerShell(runner)
return err
},
run: func(req promptRequest) error {
name, err := findPowerShell(runner)
if err != nil {
return err
}
return runner.Run(promptCommandTimeout, []string{
"NEX_TOAST_TITLE=" + req.title,
"NEX_TOAST_BODY=" + req.message,
}, name, "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "Bypass", "-EncodedCommand", encodePowerShellCommand(windowsToastScript()))
},
},
{
name: "windows-messagebox",
available: func() error {
return messageBoxAvailable()
},
run: func(req promptRequest) error {
return messageBox(req.title, req.message, messageBoxStartupFlags())
},
},
}
}
func findPowerShell(runner commandRunner) (string, error) {
for _, name := range []string{"powershell.exe", "powershell"} {
if _, err := runner.LookPath(name); err == nil {
return name, nil
}
}
return "", fmt.Errorf("PowerShell 不可用")
}
func windowsToastScript() string {
return `$ErrorActionPreference = 'Stop'
Add-Type -AssemblyName System.Runtime.WindowsRuntime
$template = [Windows.UI.Notifications.ToastTemplateType]::ToastText02
$xml = [Windows.UI.Notifications.ToastNotificationManager]::GetTemplateContent($template)
$texts = $xml.GetElementsByTagName('text')
$texts.Item(0).AppendChild($xml.CreateTextNode($env:NEX_TOAST_TITLE)) | Out-Null
$texts.Item(1).AppendChild($xml.CreateTextNode($env:NEX_TOAST_BODY)) | Out-Null
$toast = [Windows.UI.Notifications.ToastNotification]::new($xml)
[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('Nex').Show($toast)`
}
func encodePowerShellCommand(script string) string {
encoded := utf16.Encode([]rune(script))
buf := make([]byte, 0, len(encoded)*2)
for _, value := range encoded {
buf = append(buf, byte(value), byte(value>>8))
}
return base64.StdEncoding.EncodeToString(buf)
}
func messageBoxAvailable() error {
if _, err := syscall.UTF16PtrFromString("Nex"); err != nil {
return err
}
if _, err := syscall.UTF16PtrFromString("test"); err != nil {
return err
}
return procMessageBoxW.Find()
}
func messageBoxStartupFlags() uint {
return mbOK | mbIconError | mbTaskModal | mbSetForeground | mbTopMost
}
func messageBox(title, message string, flags uint) error {

View File

@@ -2,6 +2,7 @@ package main
import (
"context"
"errors"
"fmt"
"io/fs"
"net"
@@ -27,10 +28,10 @@ import (
"nex/backend/internal/service"
"nex/backend/pkg/buildinfo"
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
"gorm.io/gorm"
pkgLogger "nex/backend/pkg/logger"
)
@@ -40,31 +41,65 @@ var (
zapLogger *zap.Logger
shutdownCtx context.Context
shutdownCancel context.CancelFunc
desktopHooks = defaultDesktopRuntimeHooks()
)
type singletonLocker interface {
Lock() error
Unlock() error
}
type desktopRuntimeHooks struct {
loadConfig func() (*config.Config, config.ConfigMetadata, error)
newLock func(string) singletonLocker
listen func(int) (net.Listener, error)
upgradeLogger func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error)
initDB func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error)
closeDB func(*gorm.DB)
registerAdapters func(conversion.AdapterRegistry) error
setupStaticFiles func(*gin.Engine) error
startServer func(*http.Server, net.Listener, chan<- error, *zap.Logger)
setupSystray func(int, <-chan error) error
}
func defaultDesktopRuntimeHooks() desktopRuntimeHooks {
return desktopRuntimeHooks{
loadConfig: config.LoadDesktopConfigWithMetadata,
newLock: func(lockPath string) singletonLocker { return NewSingletonLock(lockPath) },
listen: listenDesktopPort,
upgradeLogger: pkgLogger.Upgrade,
initDB: database.Init,
closeDB: database.Close,
registerAdapters: registerDesktopAdapters,
setupStaticFiles: setupStaticFiles,
startServer: startDesktopServer,
setupSystray: setupSystray,
}
}
func main() {
minimalLogger := pkgLogger.NewMinimal()
cfg, cfgMeta, err := config.LoadDesktopConfigWithMetadata()
if err != nil {
minimalLogger.Error("加载配置失败", zap.Error(err))
showError(appName, desktopConfigErrorMessage(getDesktopConfigPath(), err))
if err := runDesktop(minimalLogger); err != nil {
reportStartupFailure(err, dialogLogger())
os.Exit(1)
}
}
func runDesktop(minimalLogger *zap.Logger) error {
if minimalLogger == nil {
minimalLogger = pkgLogger.NewMinimal()
}
cfg, cfgMeta, err := desktopHooks.loadConfig()
if err != nil {
return newStartupError(phaseConfig, desktopConfigErrorMessage(getDesktopConfigPath(), err), err)
}
port := cfg.Server.Port
if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err))
showError(appName, err.Error())
os.Exit(1)
}
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
singleLock := desktopHooks.newLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
if err := singleLock.Lock(); err != nil {
minimalLogger.Error("已有 Nex 实例运行")
showError(appName, "已有 Nex 实例运行")
os.Exit(1)
return newStartupError(phaseSingleton, "已有 Nex 实例运行", err)
}
defer func() {
if err := singleLock.Unlock(); err != nil {
@@ -72,7 +107,13 @@ func main() {
}
}()
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
listener, err := desktopHooks.listen(port)
if err != nil {
return newStartupError(phasePort, desktopPortUnavailableMessage(port), err)
}
defer listener.Close()
zapLogger, err = desktopHooks.upgradeLogger(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
@@ -81,7 +122,7 @@ func main() {
Compress: cfg.Log.Compress,
})
if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
return newStartupError(phaseLogger, fmt.Sprintf("初始化日志失败\n\n日志目录: %s\n\n请检查目录权限或磁盘空间", cfg.Log.Path), err)
}
defer func() {
if err := zapLogger.Sync(); err != nil {
@@ -91,11 +132,17 @@ func main() {
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger)
db, err := desktopHooks.initDB(&cfg.Database, zapLogger)
if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
phase := phaseDatabase
message := fmt.Sprintf("数据库初始化失败\n\n请检查数据库配置、文件权限或连接状态\n\n%v", err)
if errors.Is(err, database.ErrMigration) {
phase = phaseMigration
message = fmt.Sprintf("数据库迁移失败\n\n请查看日志或检查数据库迁移权限\n\n%v", err)
}
return newStartupError(phase, message, err)
}
defer database.Close(db)
defer desktopHooks.closeDB(db)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
@@ -118,11 +165,8 @@ func main() {
statsService := service.NewStatsService(statsRepo, statsBuffer)
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
if err := desktopHooks.registerAdapters(registry); err != nil {
return newStartupError(phaseAdapter, startupInternalErrorMessage(), err)
}
engine := conversion.NewConversionEngine(registry, zapLogger)
@@ -144,7 +188,9 @@ func main() {
r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler, settingsHandler)
setupStaticFiles(r)
if err := desktopHooks.setupStaticFiles(r); err != nil {
return newStartupError(phaseStaticResource, startupInternalErrorMessage(), err)
}
server = &http.Server{
Addr: desktopListenAddr(port),
@@ -154,26 +200,46 @@ func main() {
}
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
defer doShutdown()
serverErrCh := make(chan error, 1)
desktopHooks.startServer(server, listener, serverErrCh, zapLogger)
select {
case err := <-serverErrCh:
return newStartupError(phaseServer, startupServerErrorMessage(), err)
case <-time.After(50 * time.Millisecond):
}
if err := desktopHooks.setupSystray(port, serverErrCh); err != nil {
return err
}
select {
case err := <-serverErrCh:
return newStartupError(phaseServer, startupServerErrorMessage(), err)
default:
return nil
}
}
func registerDesktopAdapters(registry conversion.AdapterRegistry) error {
if err := registry.Register(openai.NewAdapter()); err != nil {
return err
}
return registry.Register(anthropic.NewAdapter())
}
func startDesktopServer(server *http.Server, listener net.Listener, serverErrCh chan<- error, logger *zap.Logger) {
go func() {
zapLogger.Info("AI Gateway 启动",
logger.Info("AI Gateway 启动",
zap.String("addr", server.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.Error(err))
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
serverErrCh <- err
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
if err := openBrowser(desktopURL(port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.Error(err))
}
}()
setupSystray(port)
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler, settingsHandler *handler.SettingsHandler) {
@@ -223,12 +289,13 @@ func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
}
}
func setupStaticFiles(r *gin.Engine) {
func setupStaticFiles(r *gin.Engine) error {
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
return err
}
setupStaticFilesWithFS(r, distFS)
return nil
}
func frontendDistFS() (fs.FS, error) {
@@ -299,47 +366,6 @@ func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
})
}
func setupSystray(port int) {
systray.Run(func() {
var icon []byte
var err error
if runtime.GOOS == "windows" {
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
} else {
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
}
if err != nil {
zapLogger.Error("无法加载托盘图标", zap.Error(err))
}
systray.SetIcon(icon)
systray.SetTooltip(appTooltip)
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator()
mStatus := systray.AddMenuItem("状态: 运行中", "")
mStatus.Disable()
mPort := systray.AddMenuItem(desktopPortMenuTitle(port), "")
mPort.Disable()
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() {
for {
select {
case <-mOpen.ClickedCh:
if err := openBrowser(desktopURL(port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mQuit.ClickedCh:
doShutdown()
systray.Quit()
return
}
}
}()
}, nil)
}
func doShutdown() {
if zapLogger != nil {
zapLogger.Info("正在关闭服务器...")
@@ -382,13 +408,12 @@ func desktopPortMenuTitle(port int) string {
return fmt.Sprintf("端口: %d", port)
}
func checkPortAvailable(port int) error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
}
ln.Close()
return nil
func listenDesktopPort(port int) (net.Listener, error) {
return net.Listen("tcp", desktopListenAddr(port))
}
func desktopPortUnavailableMessage(port int) string {
return fmt.Sprintf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
}
type SingletonLock struct {

View File

@@ -47,9 +47,15 @@ func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
}
func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
old := buildPromptChannels
buildPromptChannels = func(commandRunner) []promptChannel {
return []promptChannel{{
name: "fake-failed-channel",
available: func() error { return nil },
run: func(promptRequest) error { return syscall.Errno(5) },
}}
}
t.Cleanup(func() { buildPromptChannels = old })
defer func() {
if recovered := recover(); recovered != nil {
@@ -59,3 +65,42 @@ func TestShowError_WindowsBranch(t *testing.T) {
showError("测试错误", "这是一条测试错误消息")
}
func TestMessageBoxW_WindowsOnly_StartupFlags(t *testing.T) {
var gotFlags uintptr
withMessageBoxW(t, func(_, _, _, flags uintptr) (uintptr, error) {
gotFlags = flags
return 1, syscall.Errno(0)
})
if err := messageBox("测试标题", "测试消息", messageBoxStartupFlags()); err != nil {
t.Fatalf("MessageBoxW 应成功: %v", err)
}
for _, flag := range []uint{mbIconError, mbTaskModal, mbSetForeground, mbTopMost} {
if gotFlags&uintptr(flag) == 0 {
t.Fatalf("startup flags 缺少 0x%x实际: 0x%x", flag, gotFlags)
}
}
}
func TestWindowsStartupChannelsUseToastBeforeMessageBox(t *testing.T) {
runner := &fakeCommandRunner{paths: map[string]bool{"powershell.exe": true}}
channels := platformStartupChannels(runner)
if len(channels) != 2 {
t.Fatalf("Windows 应有 Toast 和 MessageBox 两级通道,实际: %d", len(channels))
}
if channels[0].name != "windows-toast" || channels[1].name != "windows-messagebox" {
t.Fatalf("Windows 通道顺序错误: %s, %s", channels[0].name, channels[1].name)
}
if err := channels[0].available(); err != nil {
t.Fatalf("PowerShell 存在时 Toast 通道应可用: %v", err)
}
if err := channels[0].run(promptRequest{title: "Nex 启动失败", message: "端口被占用"}); err != nil {
t.Fatalf("Toast fake runner 应执行成功: %v", err)
}
if len(runner.calls) != 1 || runner.calls[0].name != "powershell.exe" {
t.Fatalf("Toast 应调用 powershell.exe实际: %#v", runner.calls)
}
}

View File

@@ -9,87 +9,27 @@ import (
"time"
)
func TestCheckPortAvailable(t *testing.T) {
port := 19826
err := checkPortAvailable(port)
func TestListenDesktopPortReturnsReusableListener(t *testing.T) {
listener, err := listenDesktopPort(0)
if err != nil {
t.Fatalf("端口 %d 应该可用: %v", port, err)
}
t.Log("端口可用测试通过")
}
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
t.Fatalf("listener-first 应直接获取配置端口 listener: %v", err)
}
defer listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err == nil {
t.Fatal("端口被占用时应该返回错误")
}
t.Log("端口占用检测测试通过")
}
func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828
listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
server := &http.Server{ReadHeaderTimeout: time.Second}
defer server.Close()
done := make(chan struct{})
go func() {
defer close(done)
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
t.Errorf("serve failed: %v", err)
t.Errorf("使用同一个 listener 启动 server 失败: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err != nil {
t.Fatalf("端口关闭后应该可用: %v", err)
if err := server.Close(); err != nil {
t.Fatalf("关闭测试 server 失败: %v", err)
}
t.Log("端口关闭后可用测试通过")
}
func TestCheckPortAvailableErrorContainsPort(t *testing.T) {
port := 19829
listener, err := net.Listen("tcp", ":19829") //nolint:gosec
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
defer listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err == nil {
t.Fatal("端口被占用时应该返回错误")
}
if !strings.Contains(err.Error(), "19829") {
t.Fatalf("错误信息应包含端口号 19829实际: %v", err)
}
t.Log("端口错误信息包含端口号测试通过")
<-done
}
func TestGetDesktopConfigPath(t *testing.T) {

View File

@@ -0,0 +1,121 @@
package main
import (
"context"
"errors"
"io"
"os"
"os/exec"
"time"
"go.uber.org/zap"
)
const promptCommandTimeout = 5 * time.Second
type promptRequest struct {
title string
message string
subtitle string
}
type promptChannel struct {
name string
available func() error
run func(promptRequest) error
}
type commandRunner interface {
LookPath(file string) (string, error)
Run(timeout time.Duration, env []string, name string, args ...string) error
}
type defaultCommandRunner struct{}
var buildPromptChannels = platformStartupChannels
func (defaultCommandRunner) LookPath(file string) (string, error) {
return exec.LookPath(file)
}
func (defaultCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, name, args...)
if len(env) > 0 {
cmd.Env = append(os.Environ(), env...)
}
if err := cmd.Run(); err != nil {
return err
}
if err := ctx.Err(); err != nil {
return err
}
return nil
}
func showError(title, message string) {
reportPrompt(promptRequest{title: title, message: message}, os.Stderr, dialogLogger())
}
func reportStartupFailure(err error, logger *zap.Logger) {
if err == nil {
return
}
var startupErr *startupError
if !errors.As(err, &startupErr) {
startupErr = newStartupError(phaseServer, startupServerErrorMessage(), err)
}
if logger == nil {
logger = dialogLogger()
}
logger.Error("desktop 启动失败",
zap.String("phase", startupErr.Phase()),
zap.Error(startupErr))
reportPrompt(promptRequest{
title: startupTitle(),
message: startupErr.UserMessage(),
subtitle: startupErr.Phase(),
}, os.Stderr, logger)
}
func reportPrompt(req promptRequest, fallback io.Writer, logger *zap.Logger) {
runPromptPipeline(req, buildPromptChannels(defaultCommandRunner{}), fallback, logger)
}
func runPromptPipeline(req promptRequest, channels []promptChannel, fallback io.Writer, logger *zap.Logger) {
if logger == nil {
logger = dialogLogger()
}
for _, channel := range channels {
if channel.available != nil {
if err := channel.available(); err != nil {
logger.Warn("提示通道不可用", zap.String("channel", channel.name), zap.Error(err))
continue
}
}
if err := channel.run(req); err != nil {
logger.Warn("提示通道执行失败", zap.String("channel", channel.name), zap.Error(err))
continue
}
return
}
writePromptFallback(fallback, req.title, req.message)
}
func writePromptFallback(w io.Writer, title, message string) {
if w == nil {
return
}
if _, err := io.WriteString(w, "错误: "+title+": "+message+"\n"); err != nil {
return
}
}

View File

@@ -0,0 +1,140 @@
package main
import (
"bytes"
"errors"
"fmt"
"os/exec"
"strings"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
)
type commandCall struct {
timeout time.Duration
env []string
name string
args []string
}
type fakeCommandRunner struct {
paths map[string]bool
runErrs map[string]error
calls []commandCall
}
func (r *fakeCommandRunner) LookPath(file string) (string, error) {
if r.paths[file] {
return "/usr/bin/" + file, nil
}
return "", exec.ErrNotFound
}
func (r *fakeCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
r.calls = append(r.calls, commandCall{
timeout: timeout,
env: append([]string(nil), env...),
name: name,
args: append([]string(nil), args...),
})
if err := r.runErrs[name]; err != nil {
return err
}
return nil
}
func TestRunPromptPipelineFallbackOrder(t *testing.T) {
var calls []string
channels := []promptChannel{
{
name: "unavailable",
available: func() error {
calls = append(calls, "available-1")
return errors.New("missing")
},
run: func(promptRequest) error {
calls = append(calls, "run-1")
return nil
},
},
{
name: "failed",
available: func() error {
calls = append(calls, "available-2")
return nil
},
run: func(promptRequest) error {
calls = append(calls, "run-2")
return errors.New("failed")
},
},
{
name: "success",
available: func() error {
calls = append(calls, "available-3")
return nil
},
run: func(promptRequest) error {
calls = append(calls, "run-3")
return nil
},
},
}
var fallback bytes.Buffer
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "启动失败"}, channels, &fallback, zap.NewNop())
want := []string{"available-1", "available-2", "run-2", "available-3", "run-3"}
if fmt.Sprint(calls) != fmt.Sprint(want) {
t.Fatalf("调用顺序 = %v, want %v", calls, want)
}
if fallback.Len() != 0 {
t.Fatalf("成功通道后不应写入 fallback实际: %s", fallback.String())
}
}
func TestRunPromptPipelineWritesFallback(t *testing.T) {
channels := []promptChannel{
{
name: "unavailable",
available: func() error { return errors.New("missing") },
run: func(promptRequest) error { return nil },
},
}
var fallback bytes.Buffer
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "端口被占用"}, channels, &fallback, zap.NewNop())
want := "错误: Nex 启动失败: 端口被占用\n"
if fallback.String() != want {
t.Fatalf("fallback = %q, want %q", fallback.String(), want)
}
}
func TestReportStartupFailureLogsRedactedError(t *testing.T) {
old := buildPromptChannels
buildPromptChannels = func(commandRunner) []promptChannel {
return []promptChannel{{name: "fake-success", run: func(promptRequest) error { return nil }}}
}
t.Cleanup(func() { buildPromptChannels = old })
core, logs := observer.New(zap.ErrorLevel)
logger := zap.New(core)
err := errors.New("数据库连接失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test")
reportStartupFailure(err, logger)
entries := logs.All()
if len(entries) != 1 {
t.Fatalf("应记录 1 条错误日志,实际: %d", len(entries))
}
fields := fmt.Sprint(entries[0].ContextMap())
for _, secret := range []string{"secret", "sk-test"} {
if strings.Contains(fields, secret) {
t.Fatalf("启动失败日志不应包含敏感信息 %q实际: %s", secret, fields)
}
}
}

View File

@@ -0,0 +1,332 @@
package main
import (
"errors"
"fmt"
"net"
"net/http"
"path/filepath"
"sync/atomic"
"testing"
"time"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/database"
pkgLogger "nex/backend/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
type fakeDesktopLock struct {
lockErr error
unlockCount atomic.Int32
}
func (l *fakeDesktopLock) Lock() error {
return l.lockErr
}
func (l *fakeDesktopLock) Unlock() error {
l.unlockCount.Add(1)
return nil
}
func (l *fakeDesktopLock) unlocked() bool {
return l.unlockCount.Load() > 0
}
type recordingListener struct {
net.Listener
closeCount atomic.Int32
}
func newRecordingListener(t *testing.T) *recordingListener {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建测试 listener 失败: %v", err)
}
return &recordingListener{Listener: listener}
}
func (l *recordingListener) Close() error {
l.closeCount.Add(1)
return l.Listener.Close()
}
func (l *recordingListener) closed() bool {
return l.closeCount.Load() > 0
}
func testDesktopConfig(t *testing.T) *config.Config {
t.Helper()
tmpDir := t.TempDir()
cfg := config.DefaultConfig()
cfg.Server.Port = 0
cfg.Database.Driver = "sqlite"
cfg.Database.Path = filepath.Join(tmpDir, "config.db")
cfg.Log.Path = filepath.Join(tmpDir, "log")
return cfg
}
func installDesktopTestHooks(t *testing.T, cfg *config.Config, mutate func(*desktopRuntimeHooks)) {
t.Helper()
oldHooks := desktopHooks
oldServer := server
oldLogger := zapLogger
oldShutdownCtx := shutdownCtx
oldShutdownCancel := shutdownCancel
server = nil
zapLogger = nil
shutdownCtx = nil
shutdownCancel = nil
hooks := defaultDesktopRuntimeHooks()
if cfg != nil {
hooks.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
return cfg, config.ConfigMetadata{ConfigPath: filepath.Join(t.TempDir(), "config.yaml")}, nil
}
}
hooks.upgradeLogger = func(_ *zap.Logger, _ pkgLogger.Config) (*zap.Logger, error) {
return zap.NewNop(), nil
}
hooks.setupStaticFiles = func(*gin.Engine) error { return nil }
hooks.startServer = func(*http.Server, net.Listener, chan<- error, *zap.Logger) {}
hooks.setupSystray = func(int, <-chan error) error { return nil }
if mutate != nil {
mutate(&hooks)
}
desktopHooks = hooks
t.Cleanup(func() {
if server != nil {
_ = server.Close()
}
desktopHooks = oldHooks
server = oldServer
zapLogger = oldLogger
shutdownCtx = oldShutdownCtx
shutdownCancel = oldShutdownCancel
})
}
func requireStartupPhase(t *testing.T, err error, want startupPhase) {
t.Helper()
if err == nil {
t.Fatalf("期望 %s 阶段启动错误,实际 nil", want)
}
var startupErr *startupError
if !errors.As(err, &startupErr) {
t.Fatalf("期望 startupError实际: %T %v", err, err)
}
if startupErr.phase != want {
t.Fatalf("phase = %s, want %s", startupErr.phase, want)
}
}
func TestRunDesktopConfigFailureReturnsConfigPhase(t *testing.T) {
installDesktopTestHooks(t, nil, func(h *desktopRuntimeHooks) {
h.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
return nil, config.ConfigMetadata{}, errors.New("yaml 解析失败")
}
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, phaseConfig)
}
func TestRunDesktopSingletonFailurePrecedesPortListen(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{lockErr: errors.New("已有实例运行")}
listenCalled := false
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) {
listenCalled = true
return nil, errors.New("不应监听端口")
}
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, phaseSingleton)
if listenCalled {
t.Fatal("单实例锁失败时不应继续监听端口")
}
}
func TestRunDesktopPortFailureUnlocksSingleton(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return nil, errors.New("bind failed") }
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, phasePort)
if !lock.unlocked() {
t.Fatal("端口监听失败时应释放单实例锁")
}
}
func TestRunDesktopLoggerFailureClosesListenerAndUnlocks(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
listener := newRecordingListener(t)
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return listener, nil }
h.upgradeLogger = func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error) {
return nil, errors.New("log permission denied")
}
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, phaseLogger)
if !listener.closed() {
t.Fatal("日志初始化失败时应关闭 listener")
}
if !lock.unlocked() {
t.Fatal("日志初始化失败时应释放单实例锁")
}
}
func TestRunDesktopDatabaseFailureClassification(t *testing.T) {
tests := []struct {
name string
err error
want startupPhase
}{
{name: "database", err: errors.New("open failed"), want: phaseDatabase},
{name: "migration", err: fmt.Errorf("%w: %w", database.ErrMigration, errors.New("goose failed")), want: phaseMigration},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
listener := newRecordingListener(t)
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return listener, nil }
h.initDB = func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error) { return nil, tt.err }
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, tt.want)
if !listener.closed() {
t.Fatal("数据库失败时应关闭 listener")
}
if !lock.unlocked() {
t.Fatal("数据库失败时应释放单实例锁")
}
})
}
}
func TestRunDesktopInternalStartupFailurePhasesAndDatabaseCleanup(t *testing.T) {
tests := []struct {
name string
mutate func(*desktopRuntimeHooks)
want startupPhase
}{
{
name: "adapter",
mutate: func(h *desktopRuntimeHooks) {
h.registerAdapters = func(conversion.AdapterRegistry) error { return errors.New("adapter failed") }
},
want: phaseAdapter,
},
{
name: "static",
mutate: func(h *desktopRuntimeHooks) {
h.setupStaticFiles = func(*gin.Engine) error { return errors.New("missing frontend") }
},
want: phaseStaticResource,
},
{
name: "server",
mutate: func(h *desktopRuntimeHooks) {
h.startServer = func(_ *http.Server, _ net.Listener, errCh chan<- error, _ *zap.Logger) {
errCh <- errors.New("serve failed")
}
},
want: phaseServer,
},
{
name: "tray",
mutate: func(h *desktopRuntimeHooks) {
h.setupSystray = func(int, <-chan error) error {
return newStartupError(phaseTray, "托盘初始化失败", errors.New("tray failed"))
}
},
want: phaseTray,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
listener := newRecordingListener(t)
closeDBCalled := false
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return listener, nil }
h.closeDB = func(db *gorm.DB) {
closeDBCalled = true
database.Close(db)
}
tt.mutate(h)
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, tt.want)
if !closeDBCalled {
t.Fatal("数据库初始化后的启动失败应关闭数据库")
}
if !listener.closed() {
t.Fatal("数据库初始化后的启动失败应关闭 listener")
}
if !lock.unlocked() {
t.Fatal("数据库初始化后的启动失败应释放单实例锁")
}
})
}
}
func TestRunDesktopBrowserFailureRemainsNonFatal(t *testing.T) {
controller := newFakeTrayController()
notified := make(chan string, 1)
controller.run = func(onReady func(), _ func()) {
onReady()
<-controller.quitCh
}
err := runSystray(19826, trayOptions{
controller: controller,
readyTimeout: time.Second,
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
openBrowser: func(string) error { return errors.New("no browser") },
notify: func(_, message string) {
notified <- message
controller.Quit()
},
logger: zap.NewNop(),
})
if err != nil {
t.Fatalf("浏览器打开失败不应导致 runSystray 返回 fatal: %v", err)
}
if got := <-notified; got == "" {
t.Fatal("浏览器打开失败应提示用户手动访问")
}
}

View File

@@ -0,0 +1,96 @@
package main
import (
"fmt"
"regexp"
)
type startupPhase string
const (
phaseConfig startupPhase = "config"
phaseSingleton startupPhase = "singleton"
phasePort startupPhase = "port"
phaseLogger startupPhase = "logger"
phaseDatabase startupPhase = "database"
phaseMigration startupPhase = "migration"
phaseAdapter startupPhase = "adapter"
phaseStaticResource startupPhase = "static"
phaseServer startupPhase = "server"
phaseTray startupPhase = "tray"
)
type startupError struct {
phase startupPhase
message string
cause error
}
func newStartupError(phase startupPhase, message string, cause error) *startupError {
return &startupError{
phase: phase,
message: redactSensitive(message),
cause: cause,
}
}
func (e *startupError) Error() string {
if e == nil {
return ""
}
if e.cause == nil {
return fmt.Sprintf("%s: %s", e.phase, e.message)
}
return fmt.Sprintf("%s: %s: %s", e.phase, e.message, redactSensitive(e.cause.Error()))
}
func (e *startupError) Unwrap() error {
if e == nil {
return nil
}
return e.cause
}
func (e *startupError) Phase() string {
if e == nil {
return ""
}
return string(e.phase)
}
func (e *startupError) UserMessage() string {
if e == nil {
return ""
}
return redactSensitive(e.message)
}
var sensitiveReplacers = []struct {
pattern *regexp.Regexp
replacement string
}{
{regexp.MustCompile(`(?i)(password\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
{regexp.MustCompile(`(?i)(api[_-]?key\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
{regexp.MustCompile(`(?i)(secret\s*[:=]\s*)[^\s,;&]+`), `${1}<redacted>`},
{regexp.MustCompile(`([^\s:/]+):([^\s@]+)@tcp\(`), `${1}:<redacted>@tcp(`},
{regexp.MustCompile(`(://[^\s:/]+):([^\s@]+)@`), `${1}:<redacted>@`},
}
func redactSensitive(s string) string {
for _, replacer := range sensitiveReplacers {
s = replacer.pattern.ReplaceAllString(s, replacer.replacement)
}
return s
}
func startupTitle() string {
return appName + " 启动失败"
}
func startupServerErrorMessage() string {
return "后端服务启动失败\n\n请检查端口占用、网络权限或查看日志获取更多信息"
}
func startupInternalErrorMessage() string {
return "应用初始化失败\n\n请查看日志或重新安装应用"
}

View File

@@ -0,0 +1,40 @@
package main
import (
"errors"
"strings"
"testing"
)
func TestStartupErrorContainsPhaseAndCause(t *testing.T) {
cause := errors.New("底层失败")
err := newStartupError(phaseDatabase, "数据库初始化失败", cause)
if err.Phase() != "database" {
t.Fatalf("phase = %q, want database", err.Phase())
}
if !errors.Is(err, cause) {
t.Fatal("startupError 应保留底层 cause")
}
if !strings.Contains(err.Error(), "database") {
t.Fatalf("错误字符串应包含 phase实际: %s", err.Error())
}
}
func TestStartupErrorRedactsSensitiveUserMessage(t *testing.T) {
message := "数据库初始化失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test"
err := newStartupError(phaseDatabase, message, errors.New("cause password=secret api_key=sk-test"))
userMessage := err.UserMessage()
for _, secret := range []string{"secret", "sk-test"} {
if strings.Contains(userMessage, secret) {
t.Fatalf("用户提示不应包含敏感信息 %q实际: %s", secret, userMessage)
}
if strings.Contains(err.Error(), secret) {
t.Fatalf("日志错误字符串不应包含敏感信息 %q实际: %s", secret, err.Error())
}
}
if !strings.Contains(userMessage, "<redacted>") {
t.Fatalf("用户提示应包含脱敏占位符,实际: %s", userMessage)
}
}

View File

@@ -13,14 +13,15 @@ import (
func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
setupStaticFilesWithFS(r, distFS)
setupStaticFilesWithFS(r, fstest.MapFS{
"index.html": {Data: []byte("<html>fallback</html>")},
"icon.png": {Data: []byte("png")},
"assets/test.js": {Data: []byte("console.log('test')")},
"assets/test.css": {Data: []byte("body {}")},
"assets/test.svg": {Data: []byte("<svg></svg>")},
"assets/test.woff": {Data: []byte("font")},
})
t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
@@ -73,13 +74,12 @@ func TestSetupStaticFiles(t *testing.T) {
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "application/javascript"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
if w.Code != http.StatusOK {
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
}
expected := "application/javascript"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
})
@@ -88,13 +88,12 @@ func TestSetupStaticFiles(t *testing.T) {
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "text/css"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
if w.Code != http.StatusOK {
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
}
expected := "text/css"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
})
@@ -128,12 +127,6 @@ func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
@@ -148,7 +141,10 @@ func TestWithProtocolAndStaticRoutes(t *testing.T) {
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
setupStaticFilesWithFS(r, fstest.MapFS{
"index.html": {Data: []byte("<html>fallback</html>")},
"assets/test.js": {Data: []byte("console.log('test')")},
})
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
@@ -199,14 +195,11 @@ func TestWithProtocolAndStaticRoutes(t *testing.T) {
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
if w.Code != http.StatusOK {
t.Fatalf("期望静态资源返回 200, 实际 %d", w.Code)
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
})

231
backend/cmd/desktop/tray.go Normal file
View File

@@ -0,0 +1,231 @@
package main
import (
"fmt"
"runtime"
"sync"
"time"
"nex/embedfs"
"github.com/getlantern/systray"
"go.uber.org/zap"
)
const defaultTrayReadyTimeout = 5 * time.Second
type trayMenuItem interface {
Disable()
Clicked() <-chan struct{}
}
type trayController interface {
Run(onReady func(), onExit func())
Quit()
SetIcon(icon []byte)
SetTooltip(tooltip string)
AddMenuItem(title, tooltip string) trayMenuItem
AddSeparator()
}
type realTrayController struct{}
func (realTrayController) Run(onReady func(), onExit func()) {
systray.Run(onReady, onExit)
}
func (realTrayController) Quit() {
systray.Quit()
}
func (realTrayController) SetIcon(icon []byte) {
systray.SetIcon(icon)
}
func (realTrayController) SetTooltip(tooltip string) {
systray.SetTooltip(tooltip)
}
func (realTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
return realTrayMenuItem{item: systray.AddMenuItem(title, tooltip)}
}
func (realTrayController) AddSeparator() {
systray.AddSeparator()
}
type realTrayMenuItem struct {
item *systray.MenuItem
}
func (m realTrayMenuItem) Disable() {
m.item.Disable()
}
func (m realTrayMenuItem) Clicked() <-chan struct{} {
return m.item.ClickedCh
}
type trayOptions struct {
controller trayController
readyTimeout time.Duration
iconLoader func() ([]byte, error)
openBrowser func(string) error
notify func(string, string)
logger *zap.Logger
fatalErrCh <-chan error
}
func setupSystray(port int, fatalErrCh <-chan error) error {
return runSystray(port, trayOptions{
controller: realTrayController{},
readyTimeout: defaultTrayReadyTimeout,
iconLoader: loadTrayIcon,
openBrowser: openBrowser,
notify: showError,
logger: dialogLogger(),
fatalErrCh: fatalErrCh,
})
}
func runSystray(port int, opts trayOptions) error {
if opts.controller == nil {
opts.controller = realTrayController{}
}
if opts.readyTimeout <= 0 {
opts.readyTimeout = defaultTrayReadyTimeout
}
if opts.iconLoader == nil {
opts.iconLoader = loadTrayIcon
}
if opts.openBrowser == nil {
opts.openBrowser = openBrowser
}
if opts.notify == nil {
opts.notify = showError
}
if opts.logger == nil {
opts.logger = dialogLogger()
}
readyCh := make(chan struct{})
doneCh := make(chan struct{})
errCh := make(chan error, 1)
var readyOnce sync.Once
var errOnce sync.Once
signalReady := func() {
readyOnce.Do(func() { close(readyCh) })
}
signalError := func(err error) {
errOnce.Do(func() { errCh <- err })
}
go monitorTrayStartup(port, opts, readyCh, doneCh, signalError)
opts.controller.Run(func() {
handleTrayReady(port, opts, signalReady, signalError)
}, nil)
close(doneCh)
select {
case err := <-errCh:
return err
default:
return nil
}
}
func monitorTrayStartup(port int, opts trayOptions, readyCh <-chan struct{}, doneCh <-chan struct{}, signalError func(error)) {
timer := time.NewTimer(opts.readyTimeout)
defer timer.Stop()
ready := false
for {
select {
case <-readyCh:
ready = true
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
openDesktopBrowser(port, opts)
readyCh = nil
case <-timer.C:
if !ready {
signalError(newStartupError(phaseTray, "托盘初始化超时", fmt.Errorf("托盘未在 %s 内 ready", opts.readyTimeout)))
opts.controller.Quit()
}
case err := <-opts.fatalErrCh:
if err != nil {
signalError(newStartupError(phaseServer, startupServerErrorMessage(), err))
opts.controller.Quit()
}
case <-doneCh:
return
}
}
}
func handleTrayReady(port int, opts trayOptions, signalReady func(), signalError func(error)) {
defer func() {
if recovered := recover(); recovered != nil {
err := fmt.Errorf("托盘初始化 panic: %v", recovered)
signalError(newStartupError(phaseTray, "托盘菜单初始化失败", err))
opts.controller.Quit()
}
}()
icon, err := opts.iconLoader()
if err != nil {
signalError(newStartupError(phaseTray, "托盘图标资源无法加载", err))
opts.controller.Quit()
return
}
opts.controller.SetIcon(icon)
opts.controller.SetTooltip(appTooltip)
mOpen := opts.controller.AddMenuItem("打开管理界面", "在浏览器中打开")
opts.controller.AddSeparator()
mStatus := opts.controller.AddMenuItem("状态: 运行中", "")
mStatus.Disable()
mPort := opts.controller.AddMenuItem(desktopPortMenuTitle(port), "")
mPort.Disable()
opts.controller.AddSeparator()
mQuit := opts.controller.AddMenuItem("退出", "停止服务并退出")
go func() {
for {
select {
case <-mOpen.Clicked():
if err := opts.openBrowser(desktopURL(port)); err != nil {
opts.logger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mQuit.Clicked():
doShutdown()
opts.controller.Quit()
return
}
}
}()
signalReady()
}
func openDesktopBrowser(port int, opts trayOptions) {
url := desktopURL(port)
if err := opts.openBrowser(url); err != nil {
opts.logger.Warn("无法打开浏览器", zap.Error(err))
opts.notify(appName, fmt.Sprintf("无法自动打开浏览器,请手动访问 %s", url))
}
}
func loadTrayIcon() ([]byte, error) {
if runtime.GOOS == "windows" {
return embedfs.Assets.ReadFile("assets/icon.ico")
}
return embedfs.Assets.ReadFile("assets/icon.png")
}

View File

@@ -0,0 +1,169 @@
package main
import (
"errors"
"sync"
"testing"
"time"
"go.uber.org/zap"
)
type fakeTrayController struct {
run func(onReady func(), onExit func())
quitCh chan struct{}
quitOnce sync.Once
icon []byte
tooltip string
menuItems []*fakeTrayMenuItem
}
func newFakeTrayController() *fakeTrayController {
return &fakeTrayController{quitCh: make(chan struct{})}
}
func (c *fakeTrayController) Run(onReady func(), onExit func()) {
if c.run != nil {
c.run(onReady, onExit)
return
}
onReady()
<-c.quitCh
if onExit != nil {
onExit()
}
}
func (c *fakeTrayController) Quit() {
c.quitOnce.Do(func() { close(c.quitCh) })
}
func (c *fakeTrayController) SetIcon(icon []byte) {
c.icon = append([]byte(nil), icon...)
}
func (c *fakeTrayController) SetTooltip(tooltip string) {
c.tooltip = tooltip
}
func (c *fakeTrayController) AddMenuItem(title, tooltip string) trayMenuItem {
item := &fakeTrayMenuItem{clicked: make(chan struct{}), title: title, tooltip: tooltip}
c.menuItems = append(c.menuItems, item)
return item
}
func (c *fakeTrayController) AddSeparator() {}
type fakeTrayMenuItem struct {
clicked chan struct{}
title string
tooltip string
disabled bool
}
func (m *fakeTrayMenuItem) Disable() {
m.disabled = true
}
func (m *fakeTrayMenuItem) Clicked() <-chan struct{} {
return m.clicked
}
func TestRunSystrayReadyOpensBrowser(t *testing.T) {
controller := newFakeTrayController()
opened := make(chan string, 1)
err := runSystray(19826, trayOptions{
controller: controller,
readyTimeout: time.Second,
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
openBrowser: func(url string) error {
opened <- url
controller.Quit()
return nil
},
notify: func(string, string) {},
logger: zap.NewNop(),
})
if err != nil {
t.Fatalf("托盘 ready 成功不应返回错误: %v", err)
}
if got := <-opened; got != "http://localhost:19826" {
t.Fatalf("浏览器 URL = %s", got)
}
if string(controller.icon) != "icon" {
t.Fatalf("应设置托盘图标")
}
if controller.tooltip != appTooltip {
t.Fatalf("tooltip = %q, want %q", controller.tooltip, appTooltip)
}
}
func TestRunSystrayReadyTimeoutReturnsTrayStartupError(t *testing.T) {
controller := newFakeTrayController()
controller.run = func(_ func(), _ func()) {
<-controller.quitCh
}
err := runSystray(19826, trayOptions{
controller: controller,
readyTimeout: 10 * time.Millisecond,
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
openBrowser: func(string) error { return nil },
notify: func(string, string) {},
logger: zap.NewNop(),
})
if err == nil {
t.Fatal("托盘 ready timeout 应返回错误")
}
var startupErr *startupError
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
}
}
func TestRunSystrayIconLoadFailureReturnsTrayStartupError(t *testing.T) {
controller := newFakeTrayController()
err := runSystray(19826, trayOptions{
controller: controller,
readyTimeout: time.Second,
iconLoader: func() ([]byte, error) { return nil, errors.New("missing icon") },
openBrowser: func(string) error { return nil },
notify: func(string, string) {},
logger: zap.NewNop(),
})
if err == nil {
t.Fatal("托盘图标加载失败应返回错误")
}
var startupErr *startupError
if !errors.As(err, &startupErr) || startupErr.Phase() != "tray" {
t.Fatalf("应返回 tray 阶段启动错误,实际: %v", err)
}
}
func TestRunSystrayBrowserOpenFailureIsNonFatal(t *testing.T) {
controller := newFakeTrayController()
notified := make(chan string, 1)
err := runSystray(19826, trayOptions{
controller: controller,
readyTimeout: time.Second,
iconLoader: func() ([]byte, error) { return []byte("icon"), nil },
openBrowser: func(string) error { return errors.New("no browser") },
notify: func(_, message string) {
notified <- message
controller.Quit()
},
logger: zap.NewNop(),
})
if err != nil {
t.Fatalf("浏览器打开失败不应成为 fatal: %v", err)
}
if got := <-notified; got == "" {
t.Fatal("浏览器打开失败应提示用户")
}
}

View File

@@ -2,6 +2,7 @@ package database
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
@@ -17,6 +18,8 @@ import (
pkglogger "nex/backend/pkg/logger"
)
var ErrMigration = errors.New("数据库迁移失败")
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
moduleLogger := pkglogger.WithModule(zapLogger, "database")
@@ -26,7 +29,7 @@ func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
}
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
return nil, fmt.Errorf("%w: %w", ErrMigration, err)
}
configurePool(db, cfg, moduleLogger)