662 lines
20 KiB
Go
662 lines
20 KiB
Go
package handler
|
||
|
||
import (
|
||
"bufio"
|
||
"encoding/json"
|
||
"errors"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"nex/backend/internal/conversion"
|
||
"nex/backend/internal/conversion/canonical"
|
||
"nex/backend/internal/domain"
|
||
"nex/backend/internal/provider"
|
||
"nex/backend/internal/service"
|
||
appErrors "nex/backend/pkg/errors"
|
||
"nex/backend/pkg/modelid"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
|
||
pkglogger "nex/backend/pkg/logger"
|
||
)
|
||
|
||
// ProxyHandler 统一代理处理器
|
||
type ProxyHandler struct {
|
||
engine *conversion.ConversionEngine
|
||
client provider.ProviderClient
|
||
routingService service.RoutingService
|
||
providerService service.ProviderService
|
||
statsService service.StatsService
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewProxyHandler 创建统一代理处理器
|
||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||
return &ProxyHandler{
|
||
engine: engine,
|
||
client: client,
|
||
routingService: routingService,
|
||
providerService: providerService,
|
||
statsService: statsService,
|
||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||
}
|
||
}
|
||
|
||
// HandleProxy 处理代理请求
|
||
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||
clientProtocol := c.Param("protocol")
|
||
if clientProtocol == "" {
|
||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
|
||
return
|
||
}
|
||
|
||
// 原始路径: /{path}
|
||
path := c.Param("path")
|
||
if !strings.HasPrefix(path, "/") {
|
||
path = "/" + path
|
||
}
|
||
nativePath := path
|
||
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
|
||
|
||
// 获取 client adapter
|
||
registry := h.engine.GetRegistry()
|
||
clientAdapter, err := registry.Get(clientProtocol)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||
return
|
||
}
|
||
|
||
// 检测接口类型
|
||
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||
|
||
// 处理 Models 接口:本地聚合
|
||
if ifaceType == conversion.InterfaceTypeModels {
|
||
h.handleModelsList(c, clientAdapter)
|
||
return
|
||
}
|
||
|
||
// 处理 ModelInfo 接口:本地查询
|
||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
|
||
return
|
||
}
|
||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||
return
|
||
}
|
||
|
||
// 读取请求体
|
||
body, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
|
||
return
|
||
}
|
||
|
||
// 构建输入 HTTPRequestSpec
|
||
inSpec := conversion.HTTPRequestSpec{
|
||
URL: requestPath,
|
||
Method: c.Request.Method,
|
||
Headers: extractHeaders(c),
|
||
Body: body,
|
||
}
|
||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||
|
||
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
|
||
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
|
||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||
return
|
||
}
|
||
|
||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||
if err != nil {
|
||
if isInvalidJSONError(err) {
|
||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
|
||
return
|
||
}
|
||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||
return
|
||
}
|
||
|
||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||
if err != nil {
|
||
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
|
||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||
return
|
||
}
|
||
if providerID == "" || modelName == "" {
|
||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||
return
|
||
}
|
||
|
||
// 路由
|
||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||
if err != nil {
|
||
h.writeRouteError(c, err)
|
||
return
|
||
}
|
||
|
||
// 确定 providerProtocol
|
||
providerProtocol := routeResult.Provider.Protocol
|
||
if providerProtocol == "" {
|
||
providerProtocol = "openai"
|
||
}
|
||
|
||
// 构建 TargetProvider
|
||
// 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体
|
||
// 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名)
|
||
// 跨协议:全量转换时 ModelName 会被编码到请求体中
|
||
targetProvider := conversion.NewTargetProvider(
|
||
routeResult.Provider.BaseURL,
|
||
routeResult.Provider.APIKey,
|
||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||
)
|
||
|
||
// 计算统一模型 ID(用于响应覆写)
|
||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||
|
||
if isStream {
|
||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||
} else {
|
||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||
}
|
||
}
|
||
|
||
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||
switch ifaceType {
|
||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// isInvalidJSONError reports whether err, or any error it wraps, is an
// encoding/json syntax error (*json.SyntaxError) or type-mismatch error
// (*json.UnmarshalTypeError). A nil error yields false.
func isInvalidJSONError(err error) bool {
	if syntaxErr := (*json.SyntaxError)(nil); errors.As(err, &syntaxErr) {
		return true
	}
	typeErr := (*json.UnmarshalTypeError)(nil)
	return errors.As(err, &typeErr)
}
|
||
|
||
// appendRawQuery returns path followed by "?" and rawQuery, or path unchanged
// when rawQuery is empty. rawQuery is appended verbatim (callers pass an
// already-encoded query such as url.URL.RawQuery).
func appendRawQuery(path, rawQuery string) string {
	if rawQuery != "" {
		return path + "?" + rawQuery
	}
	return path
}
|
||
|
||
// handleNonStream 处理非流式请求
|
||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||
// 转换请求
|
||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||
if err != nil {
|
||
h.logger.Error("转换请求失败", zap.Error(err))
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
// 发送请求
|
||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||
if err != nil {
|
||
h.logger.Error("发送请求失败", zap.Error(err))
|
||
h.writeUpstreamUnavailable(c, err)
|
||
return
|
||
}
|
||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||
h.writeUpstreamResponse(c, *resp)
|
||
return
|
||
}
|
||
|
||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||
if err != nil {
|
||
h.logger.Error("转换响应失败", zap.Error(err))
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
h.writeConvertedResponse(c, *convertedResp)
|
||
|
||
go func() {
|
||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||
}()
|
||
}
|
||
|
||
// handleStream 处理流式请求
|
||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||
// 转换请求
|
||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||
if err != nil {
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
// 发送流式请求
|
||
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||
if err != nil {
|
||
h.writeUpstreamUnavailable(c, err)
|
||
return
|
||
}
|
||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||
StatusCode: streamResp.StatusCode,
|
||
Headers: streamResp.Headers,
|
||
Body: streamResp.Body,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||
if err != nil {
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
|
||
writer := bufio.NewWriter(c.Writer)
|
||
flushed := false
|
||
|
||
for event := range streamResp.Events {
|
||
if event.Error != nil {
|
||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||
break
|
||
}
|
||
if event.Done {
|
||
// flush 转换器
|
||
chunks := streamConverter.Flush()
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||
}
|
||
flushed = true
|
||
break
|
||
}
|
||
|
||
chunks := streamConverter.ProcessChunk(event.Data)
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||
break
|
||
}
|
||
}
|
||
if !flushed {
|
||
chunks := streamConverter.Flush()
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
go func() {
|
||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||
}()
|
||
}
|
||
|
||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||
for _, chunk := range chunks {
|
||
if _, err := writer.Write(chunk); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := writer.Flush(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// isStreamRequest 判断是否流式请求
|
||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if ifaceType != conversion.InterfaceTypeChat {
|
||
return false
|
||
}
|
||
|
||
var req struct {
|
||
Stream bool `json:"stream"`
|
||
}
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
return false
|
||
}
|
||
return req.Stream
|
||
}
|
||
|
||
// handleModelsList 处理 GET /v1/models 本地聚合
|
||
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
|
||
// 从数据库查询所有启用的模型
|
||
models, err := h.providerService.ListEnabledModels()
|
||
if err != nil {
|
||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
|
||
return
|
||
}
|
||
|
||
// 构建 CanonicalModelList
|
||
modelList := &canonical.CanonicalModelList{
|
||
Models: make([]canonical.CanonicalModel, 0, len(models)),
|
||
}
|
||
|
||
for _, m := range models {
|
||
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
|
||
ID: m.UnifiedModelID(),
|
||
Name: m.ModelName,
|
||
Created: m.CreatedAt.Unix(),
|
||
OwnedBy: m.ProviderID,
|
||
})
|
||
}
|
||
|
||
// 使用 adapter 编码返回
|
||
body, err := adapter.EncodeModelsResponse(modelList)
|
||
if err != nil {
|
||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||
return
|
||
}
|
||
|
||
c.Data(http.StatusOK, "application/json", body)
|
||
}
|
||
|
||
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
|
||
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
|
||
// 解析统一模型 ID
|
||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
|
||
return
|
||
}
|
||
|
||
// 从数据库查询模型
|
||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
|
||
return
|
||
}
|
||
|
||
// 构建 CanonicalModelInfo
|
||
modelInfo := &canonical.CanonicalModelInfo{
|
||
ID: model.UnifiedModelID(),
|
||
Name: model.ModelName,
|
||
Created: model.CreatedAt.Unix(),
|
||
OwnedBy: model.ProviderID,
|
||
}
|
||
|
||
// 使用 adapter 编码返回
|
||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||
if err != nil {
|
||
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||
return
|
||
}
|
||
|
||
c.Data(http.StatusOK, "application/json", body)
|
||
}
|
||
|
||
// writeConversionError 写入网关层转换错误
|
||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||
var convErr *conversion.ConversionError
|
||
if errors.As(err, &convErr) {
|
||
statusCode, code, message := mapConversionError(convErr)
|
||
h.writeProxyError(c, statusCode, code, message)
|
||
return
|
||
}
|
||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||
}
|
||
|
||
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
|
||
switch err.Code {
|
||
case conversion.ErrorCodeJSONParseError:
|
||
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
|
||
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
|
||
}
|
||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||
case conversion.ErrorCodeInvalidInput,
|
||
conversion.ErrorCodeMissingRequiredField,
|
||
conversion.ErrorCodeProtocolConstraint:
|
||
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
|
||
case conversion.ErrorCodeInterfaceNotSupported:
|
||
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
|
||
case conversion.ErrorCodeUnsupportedMultimodal:
|
||
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
|
||
default:
|
||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||
}
|
||
}
|
||
|
||
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
|
||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||
switch appErr.Code {
|
||
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
|
||
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
|
||
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
|
||
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
|
||
default:
|
||
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
|
||
}
|
||
return
|
||
}
|
||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
|
||
}
|
||
|
||
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
|
||
h.logger.Error("上游不可达", zap.Error(err))
|
||
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
|
||
}
|
||
|
||
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||
c.JSON(status, gin.H{
|
||
"error": message,
|
||
"code": code,
|
||
})
|
||
}
|
||
|
||
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||
for k, v := range resp.Headers {
|
||
c.Header(k, v)
|
||
}
|
||
contentType := headerValue(resp.Headers, "Content-Type")
|
||
if contentType == "" {
|
||
contentType = "application/json"
|
||
}
|
||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||
}
|
||
|
||
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||
c.Header(k, v)
|
||
}
|
||
contentType := headerValue(resp.Headers, "Content-Type")
|
||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||
}
|
||
|
||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
|
||
registry := h.engine.GetRegistry()
|
||
adapter, err := registry.Get(clientProtocol)
|
||
if err != nil {
|
||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||
return
|
||
}
|
||
|
||
providers, err := h.providerService.List()
|
||
if err != nil || len(providers) == 0 {
|
||
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
|
||
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
|
||
return
|
||
}
|
||
|
||
p := providers[0]
|
||
providerProtocol := p.Protocol
|
||
if providerProtocol == "" {
|
||
providerProtocol = "openai"
|
||
}
|
||
|
||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||
|
||
var outSpec *conversion.HTTPRequestSpec
|
||
if clientProtocol == providerProtocol {
|
||
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
|
||
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
|
||
headers := adapter.BuildHeaders(targetProvider)
|
||
if _, ok := headers["Content-Type"]; !ok {
|
||
headers["Content-Type"] = "application/json"
|
||
}
|
||
outSpec = &conversion.HTTPRequestSpec{
|
||
URL: joinBaseURL(p.BaseURL, upstreamPath),
|
||
Method: inSpec.Method,
|
||
Headers: headers,
|
||
Body: inSpec.Body,
|
||
}
|
||
} else {
|
||
outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||
if err != nil {
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
}
|
||
|
||
if isStream {
|
||
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
|
||
return
|
||
}
|
||
|
||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||
if err != nil {
|
||
h.writeUpstreamUnavailable(c, err)
|
||
return
|
||
}
|
||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||
h.writeUpstreamResponse(c, *resp)
|
||
return
|
||
}
|
||
|
||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
|
||
if err != nil {
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
h.writeConvertedResponse(c, *convertedResp)
|
||
}
|
||
|
||
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
|
||
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
|
||
if err != nil {
|
||
h.writeUpstreamUnavailable(c, err)
|
||
return
|
||
}
|
||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||
StatusCode: streamResp.StatusCode,
|
||
Headers: streamResp.Headers,
|
||
Body: streamResp.Body,
|
||
})
|
||
return
|
||
}
|
||
|
||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
|
||
if err != nil {
|
||
h.writeConversionError(c, err, clientProtocol)
|
||
return
|
||
}
|
||
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
|
||
writer := bufio.NewWriter(c.Writer)
|
||
flushed := false
|
||
for event := range streamResp.Events {
|
||
if event.Error != nil {
|
||
h.logger.Error("透传流读取错误", zap.Error(event.Error))
|
||
break
|
||
}
|
||
if event.Done {
|
||
chunks := streamConverter.Flush()
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||
}
|
||
flushed = true
|
||
break
|
||
}
|
||
chunks := streamConverter.ProcessChunk(event.Data)
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||
break
|
||
}
|
||
}
|
||
if !flushed {
|
||
chunks := streamConverter.Flush()
|
||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||
}
|
||
}
|
||
}
|
||
|
||
// stripRawQuery returns path without its query string, i.e. everything before
// the first '?'. A path with no '?' is returned unchanged.
func stripRawQuery(path string) string {
	if i := strings.IndexByte(path, '?'); i >= 0 {
		return path[:i]
	}
	return path
}
|
||
|
||
// rawQueryFromPath returns everything after the first '?' in path, or "" when
// path has no query string.
func rawQueryFromPath(path string) string {
	i := strings.IndexByte(path, '?')
	if i < 0 {
		return ""
	}
	return path[i+1:]
}
|
||
|
||
// joinBaseURL concatenates baseURL and path with exactly one '/' between them.
// All trailing slashes are trimmed from baseURL and all leading slashes from
// path (so an empty path yields a trailing '/').
func joinBaseURL(baseURL, path string) string {
	base := strings.TrimRight(baseURL, "/")
	rest := strings.TrimLeft(path, "/")
	return base + "/" + rest
}
|
||
|
||
// headerValue returns the value of header key in headers, matching the name
// case-insensitively, or "" when no such header exists (nil maps included).
//
// An exact-case match is checked first. Go map iteration order is unspecified,
// so when the same header is stored under several spellings (e.g.
// "Content-Type" and "content-type"), a plain case-insensitive scan could
// return a different value on each call. Preferring the exact key makes that
// common case deterministic and avoids the scan.
func headerValue(headers map[string]string, key string) string {
	if v, ok := headers[key]; ok {
		return v
	}
	for k, v := range headers {
		if strings.EqualFold(k, key) {
			return v
		}
	}
	return ""
}
|
||
|
||
// filterHopByHopHeaders returns a copy of headers without hop-by-hop headers,
// comparing names case-insensitively. It returns nil for empty input and never
// modifies headers.
//
// Removed headers:
//   - the standard hop-by-hop set: Connection, Transfer-Encoding, Keep-Alive,
//     Proxy-Authenticate, Proxy-Authorization, TE, Trailer, Upgrade;
//   - any header named in the comma-separated value of the Connection header,
//     which RFC 7230 §6.1 requires proxies to remove as well.
func filterHopByHopHeaders(headers map[string]string) map[string]string {
	if len(headers) == 0 {
		return nil
	}

	// A switch avoids building a lookup map on every call.
	isStandardHopByHop := func(lowerName string) bool {
		switch lowerName {
		case "connection", "transfer-encoding", "keep-alive", "proxy-authenticate",
			"proxy-authorization", "te", "trailer", "upgrade":
			return true
		}
		return false
	}

	// Collect the (lower-cased) header names nominated by Connection. The map
	// stays nil, which is safe to read, unless Connection lists something.
	var connectionListed map[string]struct{}
	for k, v := range headers {
		if !strings.EqualFold(k, "Connection") {
			continue
		}
		for _, token := range strings.Split(v, ",") {
			name := strings.ToLower(strings.TrimSpace(token))
			if name == "" {
				continue
			}
			if connectionListed == nil {
				connectionListed = make(map[string]struct{})
			}
			connectionListed[name] = struct{}{}
		}
	}

	filtered := make(map[string]string, len(headers))
	for k, v := range headers {
		lower := strings.ToLower(k)
		if isStandardHopByHop(lower) {
			continue
		}
		if _, listed := connectionListed[lower]; listed {
			continue
		}
		filtered[k] = v
	}
	return filtered
}
|
||
|
||
// extractHeaders 从 Gin context 提取请求头
|
||
func extractHeaders(c *gin.Context) map[string]string {
|
||
headers := make(map[string]string)
|
||
for k, vs := range c.Request.Header {
|
||
if len(vs) > 0 {
|
||
headers[k] = vs[0]
|
||
}
|
||
}
|
||
return headers
|
||
}
|