1
0

fix: 修正 conversion 代理路径和错误边界

This commit is contained in:
2026-04-25 23:12:54 +08:00
parent f5c82b6980
commit 2c043c6cf7
25 changed files with 2020 additions and 214 deletions

View File

@@ -13,6 +13,7 @@ import (
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
appErrors "nex/backend/pkg/errors"
"nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
@@ -48,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol")
if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
return
}
@@ -58,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
path = "/" + path
}
nativePath := path
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
// 获取 client adapter
registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
@@ -80,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
return
}
h.handleModelInfo(c, unifiedID, clientAdapter)
@@ -90,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
return
}
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{
URL: nativePath,
URL: requestPath,
Method: c.Request.Method,
Headers: extractHeaders(c),
Body: body,
}
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err != nil {
if isInvalidJSONError(err) {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
return
}
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
if providerID == "" || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
// 路由
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
return
}
h.writeError(c, err, clientProtocol)
h.writeRouteError(c, err)
return
}
@@ -143,9 +155,6 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID()
@@ -156,6 +165,28 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
}
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
return true
default:
return false
}
}
func isInvalidJSONError(err error) bool {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
@@ -170,7 +201,11 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
@@ -182,15 +217,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
return
}
// 设置响应头
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
h.writeConvertedResponse(c, *convertedResp)
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
@@ -206,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return
}
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
// 发送流式请求
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
// 发送流式请求
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -225,8 +260,9 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("流读取错误", zap.Error(event.Error))
break
@@ -237,6 +273,7 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
@@ -246,6 +283,12 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
}
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
@@ -291,7 +334,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
return
}
@@ -313,7 +356,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
@@ -325,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
return
}
// 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
return
}
@@ -351,46 +391,103 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeConversionError 写入转换错误
// writeConversionError 写入网关层转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
var convErr *conversion.ConversionError
if errors.As(err, &convErr) {
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
if encodeErr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
return
}
c.Data(statusCode, "application/json", body)
statusCode, code, message := mapConversionError(convErr)
h.writeProxyError(c, statusCode, code, message)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
}
// writeError 写入路由错误
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
switch err.Code {
case conversion.ErrorCodeJSONParseError:
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
}
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
case conversion.ErrorCodeInvalidInput,
conversion.ErrorCodeMissingRequiredField,
conversion.ErrorCodeProtocolConstraint:
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
case conversion.ErrorCodeInterfaceNotSupported:
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
case conversion.ErrorCodeUnsupportedMultimodal:
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
default:
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
}
}
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
switch appErr.Code {
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
default:
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
}
return
}
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
h.logger.Error("上游不可达", zap.Error(err))
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": message,
"code": code,
})
}
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range resp.Headers {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, resp.Body)
}
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range filterHopByHopHeaders(resp.Headers) {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
c.Data(resp.StatusCode, contentType, resp.Body)
}
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
providers, err := h.providerService.List()
if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
return
}
@@ -400,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
providerProtocol = "openai"
}
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json"
}
outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL,
URL: joinBaseURL(p.BaseURL, upstreamPath),
Method: inSpec.Method,
Headers: headers,
Body: inSpec.Body,
@@ -425,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
}
}
if isStream {
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
return
}
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
@@ -437,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return
}
for k, v := range convertedResp.Headers {
c.Header(k, v)
h.writeConvertedResponse(c, *convertedResp)
}
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
if err != nil {
h.writeUpstreamUnavailable(c, err)
return
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("透传流读取错误", zap.Error(event.Error))
break
}
if event.Done {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
}
}
func stripRawQuery(path string) string {
pathOnly, _, _ := strings.Cut(path, "?")
return pathOnly
}
func rawQueryFromPath(path string) string {
_, rawQuery, found := strings.Cut(path, "?")
if !found {
return ""
}
return rawQuery
}
func joinBaseURL(baseURL, path string) string {
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func headerValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.EqualFold(k, key) {
return v
}
}
return ""
}
func filterHopByHopHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return nil
}
hopByHop := map[string]struct{}{
"connection": {},
"transfer-encoding": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"upgrade": {},
}
filtered := make(map[string]string, len(headers))
for k, v := range headers {
if _, skip := hopByHop[strings.ToLower(k)]; skip {
continue
}
filtered[k] = v
}
return filtered
}
// extractHeaders 从 Gin context 提取请求头

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -73,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -93,7 +94,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -109,9 +110,8 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound)
routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return(nil, nil)
client := mocks.NewMockProviderClient(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
@@ -119,10 +119,11 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
}
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
@@ -131,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -146,10 +147,11 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
@@ -158,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -173,10 +175,11 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
@@ -185,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -199,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -209,12 +212,13 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "Hello")
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
@@ -223,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
return nil, context.DeadlineExceeded
})
providerSvc := mocks.NewMockProviderService(ctrl)
@@ -238,10 +242,11 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
@@ -286,7 +291,7 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
h.HandleProxy(c)
assert.Equal(t, 400, w.Code)
assert.Equal(t, 404, w.Code)
}
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
@@ -329,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -348,7 +353,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -371,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
h.writeConversionError(c, context.DeadlineExceeded, "openai")
assert.Equal(t, 500, w.Code)
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
@@ -390,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
h.writeConversionError(c, convErr, "openai")
assert.Equal(t, 500, w.Code)
assert.Equal(t, 400, w.Code)
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
t.Run("request json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
})
t.Run("response json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
})
}
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
@@ -423,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -445,7 +484,7 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -460,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -473,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -483,7 +522,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -505,7 +544,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
require.NoError(t, err)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -517,7 +556,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -532,7 +571,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
require.NoError(t, registry.Register(openai.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -544,7 +583,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -560,7 +599,7 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
require.NoError(t, registry.Register(anthropic.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
}, nil)
@@ -579,7 +618,7 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -591,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -611,7 +650,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -720,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeError(c, fmt.Errorf("model not found"), "openai")
h.writeRouteError(c, fmt.Errorf("model not found"))
assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
}
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
@@ -994,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -1012,7 +1052,7 @@ data: {"type":"message_stop"}
`)}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -1100,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error")
}
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
tests := []struct {
name string
protocol string
path string
requestPath string
baseURL string
expectedURL string
body string
responseBody string
responseModel string
}{
{
name: "openai path has no v1 after gateway prefix",
protocol: "openai",
path: "/chat/completions",
requestPath: "/openai/chat/completions",
baseURL: "https://api.test.com/v1",
expectedURL: "https://api.test.com/v1/chat/completions",
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
responseModel: "p1/gpt-4",
},
{
name: "anthropic path keeps v1 after gateway prefix",
protocol: "anthropic",
path: "/v1/messages",
requestPath: "/anthropic/v1/messages",
baseURL: "https://api.anthropic.test",
expectedURL: "https://api.anthropic.test/v1/messages",
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
responseModel: "p1/gpt-4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, tt.expectedURL, spec.URL)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(tt.responseBody),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), tt.responseModel)
})
}
}
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
StatusCode: http.StatusTooManyRequests,
Headers: map[string]string{
"Content-Type": "application/json",
"X-Upstream-Error": "rate-limit",
"Transfer-Encoding": "chunked",
},
Body: []byte(`{"error":{"message":"rate limited"}}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusTooManyRequests, w.Code)
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
}
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
StatusCode: http.StatusServiceUnavailable,
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
Body: []byte(`{"error":"upstream down"}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
assert.Empty(t, w.Header().Get("Connection"))
}
func TestFilterHopByHopHeaders(t *testing.T) {
filtered := filterHopByHopHeaders(map[string]string{
"Connection": "close",
"Transfer-Encoding": "chunked",
"Keep-Alive": "timeout=5",
"Proxy-Authenticate": "Basic",
"Proxy-Authorization": "Basic token",
"TE": "trailers",
"Trailer": "Expires",
"Upgrade": "websocket",
"Content-Type": "application/json",
"X-Request-ID": "req-1",
})
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
"X-Request-ID": "req-1",
}, filtered)
}
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"ok":true}`),
}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
}
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":`)))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
}
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
}
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Contains(t, string(spec.Body), "image_url")
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
ch := make(chan provider.StreamEvent, 3)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
}