fix: 完善转换代理行为
This commit is contained in:
@@ -3,31 +3,60 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
MB_ICONERROR = 0x10
|
||||
MB_ICONINFORMATION = 0x40
|
||||
mbIconError = 0x10
|
||||
mbIconInformation = 0x40
|
||||
)
|
||||
|
||||
var (
|
||||
user32 = syscall.NewLazyDLL("user32.dll")
|
||||
procMessageBoxW = user32.NewProc("MessageBoxW")
|
||||
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
|
||||
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
|
||||
return ret, err
|
||||
}
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
messageBox(title, message, MB_ICONERROR)
|
||||
if err := messageBox(title, message, mbIconError); err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func messageBox(title, message string, flags uint) {
|
||||
titlePtr, _ := syscall.UTF16PtrFromString(title)
|
||||
messagePtr, _ := syscall.UTF16PtrFromString(message)
|
||||
procMessageBoxW.Call(
|
||||
func messageBox(title, message string, flags uint) error {
|
||||
titlePtr, err := syscall.UTF16PtrFromString(title)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagePtr, err := syscall.UTF16PtrFromString(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret, callErr := callMessageBoxW(
|
||||
0,
|
||||
uintptr(unsafe.Pointer(messagePtr)),
|
||||
uintptr(unsafe.Pointer(titlePtr)),
|
||||
uintptr(flags),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
|
||||
return callErr
|
||||
}
|
||||
|
||||
return fmt.Errorf("MessageBoxW 调用失败")
|
||||
}
|
||||
|
||||
@@ -168,7 +168,8 @@ func main() {
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
r.Any("/v1/*path", proxyHandler.HandleProxy)
|
||||
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
@@ -199,12 +200,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
})
|
||||
}
|
||||
|
||||
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
|
||||
next(c)
|
||||
}
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
}
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
}
|
||||
|
||||
func frontendDistFS() (fs.FS, error) {
|
||||
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
}
|
||||
|
||||
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
@@ -250,7 +265,10 @@ func setupStaticFiles(r *gin.Engine) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/openai/") ||
|
||||
strings.HasPrefix(path, "/anthropic/") ||
|
||||
path == "/openai" ||
|
||||
path == "/anthropic" ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
|
||||
@@ -3,13 +3,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageBoxW_WindowsOnly(t *testing.T) {
|
||||
messageBox("测试标题", "测试消息", MB_ICONINFORMATION)
|
||||
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
|
||||
t.Helper()
|
||||
|
||||
old := callMessageBoxW
|
||||
callMessageBoxW = fn
|
||||
t.Cleanup(func() {
|
||||
callMessageBoxW = old
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
|
||||
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
|
||||
if err == nil {
|
||||
t.Fatal("包含 NUL 字符时应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 1, syscall.Errno(123)
|
||||
})
|
||||
|
||||
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
|
||||
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
err := messageBox("测试标题", "测试消息", mbIconInformation)
|
||||
if !errors.Is(err, syscall.Errno(5)) {
|
||||
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowError_WindowsBranch(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
|
||||
showError("测试错误", "这是一条测试错误消息")
|
||||
}
|
||||
|
||||
@@ -4,5 +4,6 @@ const (
|
||||
appName = "Nex"
|
||||
appTooltip = appName
|
||||
appDescription = "AI Gateway - 统一的大模型 API 网关"
|
||||
appWebsite = "https://github.com/nex/gateway"
|
||||
// #nosec G101 -- 项目官网地址不是凭据
|
||||
appWebsite = "https://github.com/nex/gateway"
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestCheckPortAvailable(t *testing.T) {
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827")
|
||||
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,73 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
@@ -79,6 +31,32 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/anthropic/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -121,3 +99,115 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
|
||||
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var gotProtocol string
|
||||
var gotPath string
|
||||
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" {
|
||||
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "anthropic" {
|
||||
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/messages" {
|
||||
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Static assets are not hijacked", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if gotProtocol != "" || gotPath != "" {
|
||||
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA path keeps fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
|
||||
t.Errorf("期望返回 HTML,实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" || gotPath != "/unknown" {
|
||||
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user