1
0
Files
nex/backend/cmd/desktop/run_desktop_test.go

333 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"errors"
"fmt"
"net"
"net/http"
"path/filepath"
"sync/atomic"
"testing"
"time"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/database"
pkgLogger "nex/backend/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
// fakeDesktopLock is an in-memory lock stub with Lock/Unlock methods. The
// tests in this file hand it out through the newLock hook so that lock
// acquisition can be forced to fail and releases can be observed.
type fakeDesktopLock struct {
	// lockErr is returned verbatim by Lock; nil means acquisition succeeds.
	lockErr error

	// unlockCount is the number of Unlock calls seen so far.
	unlockCount atomic.Int32
}

// Lock performs no real locking and simply reports the configured lockErr.
func (f *fakeDesktopLock) Lock() error {
	return f.lockErr
}

// Unlock records one release and always reports success.
func (f *fakeDesktopLock) Unlock() error {
	f.unlockCount.Add(1)
	return nil
}

// unlocked reports whether Unlock has been called at least once.
func (f *fakeDesktopLock) unlocked() bool {
	calls := f.unlockCount.Load()
	return calls >= 1
}
// recordingListener wraps a real net.Listener and counts how many times Close
// is invoked through the wrapper, so a test can assert that the code it handed
// the listener to actually closed it.
type recordingListener struct {
	net.Listener
	closeCount atomic.Int32
}

// newRecordingListener opens a TCP listener on 127.0.0.1 with port 0 and wraps
// it. The test fails immediately if the listener cannot be created.
//
// The underlying socket is released when the test (or subtest) owning t ends,
// even when the code under test never closes it; without this, a failing
// "should close listener" assertion would also leak the socket.
func newRecordingListener(t *testing.T) *recordingListener {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("创建测试 listener 失败: %v", err)
	}
	// Close the inner listener directly so this safety net is not recorded in
	// closeCount. The error is ignored on purpose: in the passing case the code
	// under test has already closed the listener and a second Close fails.
	t.Cleanup(func() { _ = listener.Close() })
	return &recordingListener{Listener: listener}
}

// Close records the call and then closes the wrapped listener, returning its
// result.
func (l *recordingListener) Close() error {
	l.closeCount.Add(1)
	return l.Listener.Close()
}

// closed reports whether Close has been called on the wrapper at least once.
func (l *recordingListener) closed() bool {
	return l.closeCount.Load() > 0
}
// testDesktopConfig builds a configuration for desktop startup tests. It
// starts from config.DefaultConfig(), sets the server port to 0, selects the
// "sqlite" database driver, and points both the database file and the log path
// at locations inside a per-test temporary directory.
func testDesktopConfig(t *testing.T) *config.Config {
	t.Helper()

	workDir := t.TempDir()
	dbFile := filepath.Join(workDir, "config.db")
	logDir := filepath.Join(workDir, "log")

	result := config.DefaultConfig()
	result.Server.Port = 0
	result.Database.Driver = "sqlite"
	result.Database.Path = dbFile
	result.Log.Path = logDir
	return result
}
// installDesktopTestHooks replaces the package-level desktopHooks with a test
// configuration and resets server, zapLogger, shutdownCtx and shutdownCancel
// to nil for the duration of the test.
//
// The replacement starts from defaultDesktopRuntimeHooks() and stubs out
// upgradeLogger (returns a no-op logger), setupStaticFiles, startServer and
// setupSystray (each does nothing and succeeds). When cfg is non-nil,
// loadConfig returns it together with metadata whose ConfigPath is a
// config.yaml inside a temporary directory. mutate, when non-nil, runs last so
// a test can override any hook.
//
// A cleanup closes whatever server is left in the package variable and then
// restores every saved global.
func installDesktopTestHooks(t *testing.T, cfg *config.Config, mutate func(*desktopRuntimeHooks)) {
	t.Helper()

	// Snapshot the globals before clearing them.
	prevHooks, prevServer, prevLogger := desktopHooks, server, zapLogger
	prevCtx, prevCancel := shutdownCtx, shutdownCancel
	server, zapLogger, shutdownCtx, shutdownCancel = nil, nil, nil, nil

	h := defaultDesktopRuntimeHooks()
	h.upgradeLogger = func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error) {
		return zap.NewNop(), nil
	}
	h.setupStaticFiles = func(*gin.Engine) error { return nil }
	h.startServer = func(*http.Server, net.Listener, chan<- error, *zap.Logger) {}
	h.setupSystray = func(int, <-chan error) error { return nil }
	if cfg != nil {
		h.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
			// The temp directory is created on each invocation of the hook.
			meta := config.ConfigMetadata{ConfigPath: filepath.Join(t.TempDir(), "config.yaml")}
			return cfg, meta, nil
		}
	}
	// Test-specific overrides go on top of the stubs above.
	if mutate != nil {
		mutate(&h)
	}
	desktopHooks = h

	t.Cleanup(func() {
		if server != nil {
			// Best-effort shutdown of a server left behind by the code under test.
			_ = server.Close()
		}
		desktopHooks = prevHooks
		server = prevServer
		zapLogger = prevLogger
		shutdownCtx = prevCtx
		shutdownCancel = prevCancel
	})
}
// requireStartupPhase fails the test unless err is non-nil, contains a
// *startupError in its chain (checked with errors.As), and that error's phase
// equals want.
func requireStartupPhase(t *testing.T, err error, want startupPhase) {
	t.Helper()
	if err == nil {
		t.Fatalf("期望 %s 阶段启动错误,实际 nil", want)
	}
	var startupErr *startupError
	if !errors.As(err, &startupErr) {
		// Separator added between "expected" and "actual" so the message reads
		// the same way as the nil-error message above.
		t.Fatalf("期望 startupError,实际: %T %v", err, err)
	}
	if startupErr.phase != want {
		t.Fatalf("phase = %s, want %s", startupErr.phase, want)
	}
}
func TestRunDesktopConfigFailureReturnsConfigPhase(t *testing.T) {
installDesktopTestHooks(t, nil, func(h *desktopRuntimeHooks) {
h.loadConfig = func() (*config.Config, config.ConfigMetadata, error) {
return nil, config.ConfigMetadata{}, errors.New("yaml 解析失败")
}
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, phaseConfig)
}
// TestRunDesktopSingletonFailurePrecedesPortListen makes the singleton lock
// fail and expects a phaseSingleton startup error without the listen hook ever
// being invoked.
func TestRunDesktopSingletonFailurePrecedesPortListen(t *testing.T) {
	cfg := testDesktopConfig(t)
	heldLock := &fakeDesktopLock{lockErr: errors.New("已有实例运行")}

	listenCalls := 0
	trackListen := func(int) (net.Listener, error) {
		listenCalls++
		return nil, errors.New("不应监听端口")
	}
	installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
		h.newLock = func(string) singletonLocker { return heldLock }
		h.listen = trackListen
	})

	requireStartupPhase(t, runDesktop(zap.NewNop()), phaseSingleton)
	if listenCalls != 0 {
		t.Fatal("单实例锁失败时不应继续监听端口")
	}
}
// TestRunDesktopPortFailureUnlocksSingleton makes the listen hook fail after
// the lock succeeded and expects a phasePort startup error plus at least one
// Unlock call on the lock.
func TestRunDesktopPortFailureUnlocksSingleton(t *testing.T) {
	cfg := testDesktopConfig(t)
	lock := new(fakeDesktopLock)
	failListen := func(int) (net.Listener, error) {
		return nil, errors.New("bind failed")
	}
	installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
		h.newLock = func(string) singletonLocker { return lock }
		h.listen = failListen
	})

	requireStartupPhase(t, runDesktop(zap.NewNop()), phasePort)
	if lock.unlocked() {
		return
	}
	t.Fatal("端口监听失败时应释放单实例锁")
}
// TestRunDesktopLoggerFailureClosesListenerAndUnlocks makes the upgradeLogger
// hook fail and expects a phaseLogger startup error, a closed listener and a
// released singleton lock.
func TestRunDesktopLoggerFailureClosesListenerAndUnlocks(t *testing.T) {
	cfg := testDesktopConfig(t)
	lock := new(fakeDesktopLock)
	ln := newRecordingListener(t)

	brokenUpgrade := func(*zap.Logger, pkgLogger.Config) (*zap.Logger, error) {
		return nil, errors.New("log permission denied")
	}
	installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
		h.newLock = func(string) singletonLocker { return lock }
		h.listen = func(int) (net.Listener, error) { return ln, nil }
		h.upgradeLogger = brokenUpgrade
	})

	requireStartupPhase(t, runDesktop(zap.NewNop()), phaseLogger)

	// Cases are evaluated top to bottom, so the listener check reports first.
	switch {
	case !ln.closed():
		t.Fatal("日志初始化失败时应关闭 listener")
	case !lock.unlocked():
		t.Fatal("日志初始化失败时应释放单实例锁")
	}
}
func TestRunDesktopDatabaseFailureClassification(t *testing.T) {
tests := []struct {
name string
err error
want startupPhase
}{
{name: "database", err: errors.New("open failed"), want: phaseDatabase},
{name: "migration", err: fmt.Errorf("%w: %w", database.ErrMigration, errors.New("goose failed")), want: phaseMigration},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
listener := newRecordingListener(t)
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return listener, nil }
h.initDB = func(*config.DatabaseConfig, *zap.Logger) (*gorm.DB, error) { return nil, tt.err }
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, tt.want)
if !listener.closed() {
t.Fatal("数据库失败时应关闭 listener")
}
if !lock.unlocked() {
t.Fatal("数据库失败时应释放单实例锁")
}
})
}
}
func TestRunDesktopInternalStartupFailurePhasesAndDatabaseCleanup(t *testing.T) {
tests := []struct {
name string
mutate func(*desktopRuntimeHooks)
want startupPhase
}{
{
name: "adapter",
mutate: func(h *desktopRuntimeHooks) {
h.registerAdapters = func(conversion.AdapterRegistry) error { return errors.New("adapter failed") }
},
want: phaseAdapter,
},
{
name: "static",
mutate: func(h *desktopRuntimeHooks) {
h.setupStaticFiles = func(*gin.Engine) error { return errors.New("missing frontend") }
},
want: phaseStaticResource,
},
{
name: "server",
mutate: func(h *desktopRuntimeHooks) {
h.startServer = func(_ *http.Server, _ net.Listener, errCh chan<- error, _ *zap.Logger) {
errCh <- errors.New("serve failed")
}
},
want: phaseServer,
},
{
name: "tray",
mutate: func(h *desktopRuntimeHooks) {
h.setupSystray = func(int, <-chan error) error {
return newStartupError(phaseTray, "托盘初始化失败", errors.New("tray failed"))
}
},
want: phaseTray,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := testDesktopConfig(t)
lock := &fakeDesktopLock{}
listener := newRecordingListener(t)
closeDBCalled := false
installDesktopTestHooks(t, cfg, func(h *desktopRuntimeHooks) {
h.newLock = func(string) singletonLocker { return lock }
h.listen = func(int) (net.Listener, error) { return listener, nil }
h.closeDB = func(db *gorm.DB) {
closeDBCalled = true
database.Close(db)
}
tt.mutate(h)
})
err := runDesktop(zap.NewNop())
requireStartupPhase(t, err, tt.want)
if !closeDBCalled {
t.Fatal("数据库初始化后的启动失败应关闭数据库")
}
if !listener.closed() {
t.Fatal("数据库初始化后的启动失败应关闭 listener")
}
if !lock.unlocked() {
t.Fatal("数据库初始化后的启动失败应释放单实例锁")
}
})
}
}
// TestRunDesktopBrowserFailureRemainsNonFatal runs runSystray with an
// openBrowser callback that always fails and expects runSystray to return nil
// while a non-empty message is delivered through the notify callback.
//
// The fake controller's run function calls onReady and then blocks on quitCh;
// the notify callback records the message and calls controller.Quit().
func TestRunDesktopBrowserFailureRemainsNonFatal(t *testing.T) {
	controller := newFakeTrayController()
	// Buffered so the notify callback can record the message without a reader.
	notified := make(chan string, 1)
	controller.run = func(onReady func(), _ func()) {
		onReady()
		<-controller.quitCh
	}
	err := runSystray(19826, trayOptions{
		controller:   controller,
		readyTimeout: time.Second,
		iconLoader:   func() ([]byte, error) { return []byte("icon"), nil },
		openBrowser:  func(string) error { return errors.New("no browser") },
		notify: func(_, message string) {
			notified <- message
			controller.Quit()
		},
		logger: zap.NewNop(),
	})
	if err != nil {
		t.Fatalf("浏览器打开失败不应导致 runSystray 返回 fatal: %v", err)
	}
	// Bound the wait: if runSystray returned without ever notifying, a bare
	// receive would hang the test until the global timeout instead of failing.
	select {
	case got := <-notified:
		if got == "" {
			t.Fatal("浏览器打开失败应提示用户手动访问")
		}
	case <-time.After(time.Second):
		t.Fatal("浏览器打开失败应提示用户手动访问")
	}
}