1
0

fix: 完善转换代理行为

This commit is contained in:
2026-04-26 21:48:17 +08:00
parent 155244433f
commit 9622d44aac
33 changed files with 1127 additions and 1117 deletions

View File

@@ -3,31 +3,60 @@
package main
import (
"errors"
"fmt"
"syscall"
"unsafe"
"go.uber.org/zap"
)
const (
MB_ICONERROR = 0x10
MB_ICONINFORMATION = 0x40
mbIconError = 0x10
mbIconInformation = 0x40
)
var (
user32 = syscall.NewLazyDLL("user32.dll")
procMessageBoxW = user32.NewProc("MessageBoxW")
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
return ret, err
}
)
func showError(title, message string) {
messageBox(title, message, MB_ICONERROR)
if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
}
}
}
func messageBox(title, message string, flags uint) {
titlePtr, _ := syscall.UTF16PtrFromString(title)
messagePtr, _ := syscall.UTF16PtrFromString(message)
procMessageBoxW.Call(
func messageBox(title, message string, flags uint) error {
titlePtr, err := syscall.UTF16PtrFromString(title)
if err != nil {
return err
}
messagePtr, err := syscall.UTF16PtrFromString(message)
if err != nil {
return err
}
ret, callErr := callMessageBoxW(
0,
uintptr(unsafe.Pointer(messagePtr)),
uintptr(unsafe.Pointer(titlePtr)),
uintptr(flags),
)
if ret != 0 {
return nil
}
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
return callErr
}
return fmt.Errorf("MessageBoxW 调用失败")
}

View File

@@ -168,7 +168,8 @@ func main() {
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy)
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
providers := r.Group("/api/providers")
{
@@ -199,12 +200,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
})
}
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
next(c)
}
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
setupStaticFilesWithFS(r, distFS)
}
func frontendDistFS() (fs.FS, error) {
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
@@ -250,7 +265,10 @@ func setupStaticFiles(r *gin.Engine) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/openai/") ||
strings.HasPrefix(path, "/anthropic/") ||
path == "/openai" ||
path == "/anthropic" ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return

View File

@@ -3,13 +3,59 @@
package main
import (
"errors"
"syscall"
"testing"
)
func TestMessageBoxW_WindowsOnly(t *testing.T) {
messageBox("测试标题", "测试消息", MB_ICONINFORMATION)
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
t.Helper()
old := callMessageBoxW
callMessageBoxW = fn
t.Cleanup(func() {
callMessageBoxW = old
})
}
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
if err == nil {
t.Fatal("包含 NUL 字符时应该返回错误")
}
}
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 1, syscall.Errno(123)
})
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
}
}
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
err := messageBox("测试标题", "测试消息", mbIconInformation)
if !errors.Is(err, syscall.Errno(5)) {
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
}
}
func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
}
}()
showError("测试错误", "这是一条测试错误消息")
}

View File

@@ -4,5 +4,6 @@ const (
appName = "Nex"
appTooltip = appName
appDescription = "AI Gateway - 统一的大模型 API 网关"
appWebsite = "https://github.com/nex/gateway"
// #nosec G101 -- 项目官网地址不是凭据
appWebsite = "https://github.com/nex/gateway"
)

View File

@@ -22,7 +22,7 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827")
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}

View File

@@ -1,73 +1,25 @@
package main
import (
"io/fs"
"net/http"
"net/http/httptest"
"strings"
"testing"
"nex/embedfs"
"github.com/gin-gonic/gin"
)
func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
r := gin.New()
r.GET("/assets/*filepath", func(c *gin.Context) {
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
setupStaticFilesWithFS(r, distFS)
t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
@@ -79,6 +31,32 @@ func TestSetupStaticFiles(t *testing.T) {
}
})
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/anthropic/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
@@ -121,3 +99,115 @@ func TestSetupStaticFiles(t *testing.T) {
t.Log("静态文件服务测试通过")
}
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
var gotPath string
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "openai" {
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
}
if gotPath != "/v1/chat/completions" {
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
}
})
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "anthropic" {
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
}
if gotPath != "/v1/messages" {
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
}
})
t.Run("Static assets are not hijacked", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
}
})
t.Run("SPA path keeps fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
t.Errorf("期望返回 HTML实际 %s", w.Header().Get("Content-Type"))
}
})
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/unknown", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
}
if gotProtocol != "openai" || gotPath != "/unknown" {
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
})
}