141 lines
3.4 KiB
Go
141 lines
3.4 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"errors"
|
||
"fmt"
|
||
"os/exec"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"go.uber.org/zap"
|
||
"go.uber.org/zap/zaptest/observer"
|
||
)
|
||
|
||
// commandCall records the arguments of a single fakeCommandRunner.Run
// invocation so tests can inspect what would have been executed.
type commandCall struct {
	// timeout is the timeout passed to Run.
	timeout time.Duration
	// env is a copy of the environment slice passed to Run.
	env []string
	// name is the command name passed to Run.
	name string
	// args is a copy of the variadic arguments passed to Run.
	args []string
}
|
||
|
||
type fakeCommandRunner struct {
|
||
paths map[string]bool
|
||
runErrs map[string]error
|
||
calls []commandCall
|
||
}
|
||
|
||
func (r *fakeCommandRunner) LookPath(file string) (string, error) {
|
||
if r.paths[file] {
|
||
return "/usr/bin/" + file, nil
|
||
}
|
||
return "", exec.ErrNotFound
|
||
}
|
||
|
||
func (r *fakeCommandRunner) Run(timeout time.Duration, env []string, name string, args ...string) error {
|
||
r.calls = append(r.calls, commandCall{
|
||
timeout: timeout,
|
||
env: append([]string(nil), env...),
|
||
name: name,
|
||
args: append([]string(nil), args...),
|
||
})
|
||
if err := r.runErrs[name]; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func TestRunPromptPipelineFallbackOrder(t *testing.T) {
|
||
var calls []string
|
||
channels := []promptChannel{
|
||
{
|
||
name: "unavailable",
|
||
available: func() error {
|
||
calls = append(calls, "available-1")
|
||
return errors.New("missing")
|
||
},
|
||
run: func(promptRequest) error {
|
||
calls = append(calls, "run-1")
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "failed",
|
||
available: func() error {
|
||
calls = append(calls, "available-2")
|
||
return nil
|
||
},
|
||
run: func(promptRequest) error {
|
||
calls = append(calls, "run-2")
|
||
return errors.New("failed")
|
||
},
|
||
},
|
||
{
|
||
name: "success",
|
||
available: func() error {
|
||
calls = append(calls, "available-3")
|
||
return nil
|
||
},
|
||
run: func(promptRequest) error {
|
||
calls = append(calls, "run-3")
|
||
return nil
|
||
},
|
||
},
|
||
}
|
||
|
||
var fallback bytes.Buffer
|
||
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "启动失败"}, channels, &fallback, zap.NewNop())
|
||
|
||
want := []string{"available-1", "available-2", "run-2", "available-3", "run-3"}
|
||
if fmt.Sprint(calls) != fmt.Sprint(want) {
|
||
t.Fatalf("调用顺序 = %v, want %v", calls, want)
|
||
}
|
||
if fallback.Len() != 0 {
|
||
t.Fatalf("成功通道后不应写入 fallback,实际: %s", fallback.String())
|
||
}
|
||
}
|
||
|
||
func TestRunPromptPipelineWritesFallback(t *testing.T) {
|
||
channels := []promptChannel{
|
||
{
|
||
name: "unavailable",
|
||
available: func() error { return errors.New("missing") },
|
||
run: func(promptRequest) error { return nil },
|
||
},
|
||
}
|
||
|
||
var fallback bytes.Buffer
|
||
runPromptPipeline(promptRequest{title: "Nex 启动失败", message: "端口被占用"}, channels, &fallback, zap.NewNop())
|
||
|
||
want := "错误: Nex 启动失败: 端口被占用\n"
|
||
if fallback.String() != want {
|
||
t.Fatalf("fallback = %q, want %q", fallback.String(), want)
|
||
}
|
||
}
|
||
|
||
func TestReportStartupFailureLogsRedactedError(t *testing.T) {
|
||
old := buildPromptChannels
|
||
buildPromptChannels = func(commandRunner) []promptChannel {
|
||
return []promptChannel{{name: "fake-success", run: func(promptRequest) error { return nil }}}
|
||
}
|
||
t.Cleanup(func() { buildPromptChannels = old })
|
||
|
||
core, logs := observer.New(zap.ErrorLevel)
|
||
logger := zap.New(core)
|
||
err := errors.New("数据库连接失败: nex:secret@tcp(localhost:3306)/nex password=secret api_key=sk-test")
|
||
|
||
reportStartupFailure(err, logger)
|
||
|
||
entries := logs.All()
|
||
if len(entries) != 1 {
|
||
t.Fatalf("应记录 1 条错误日志,实际: %d", len(entries))
|
||
}
|
||
fields := fmt.Sprint(entries[0].ContextMap())
|
||
for _, secret := range []string{"secret", "sk-test"} {
|
||
if strings.Contains(fields, secret) {
|
||
t.Fatalf("启动失败日志不应包含敏感信息 %q,实际: %s", secret, fields)
|
||
}
|
||
}
|
||
}
|