package handler import ( "bufio" "encoding/json" "io" "net/http" "strings" "github.com/gin-gonic/gin" "go.uber.org/zap" "nex/backend/internal/conversion" "nex/backend/internal/domain" "nex/backend/internal/provider" "nex/backend/internal/service" ) // ProxyHandler 统一代理处理器 type ProxyHandler struct { engine *conversion.ConversionEngine client provider.ProviderClient routingService service.RoutingService providerService service.ProviderService statsService service.StatsService logger *zap.Logger } // NewProxyHandler 创建统一代理处理器 func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler { return &ProxyHandler{ engine: engine, client: client, routingService: routingService, providerService: providerService, statsService: statsService, logger: zap.L(), } } // HandleProxy 处理代理请求 func (h *ProxyHandler) HandleProxy(c *gin.Context) { // 从 URL 提取 clientProtocol: /{protocol}/v1/... clientProtocol := c.Param("protocol") if clientProtocol == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"}) return } // 原始路径: /v1/{path} path := c.Param("path") if strings.HasPrefix(path, "/") { path = path[1:] } nativePath := "/v1/" + path // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) return } // 解析 model 名称(从 JSON body 中提取,GET 请求无 body) modelName := "" if len(body) > 0 { modelName = extractModelName(body) } // 构建输入 HTTPRequestSpec inSpec := conversion.HTTPRequestSpec{ URL: nativePath, Method: c.Request.Method, Headers: extractHeaders(c), Body: body, } // 路由 routeResult, err := h.routingService.Route(modelName) if err != nil { // GET 请求或无法提取 model 时,直接转发到上游 if len(body) == 0 || modelName == "" { h.forwardPassthrough(c, inSpec, clientProtocol) return } h.writeError(c, err, clientProtocol) return } // 确定 providerProtocol providerProtocol := routeResult.Provider.Protocol if providerProtocol == "" { providerProtocol = "openai" } // 构建 TargetProvider targetProvider := conversion.NewTargetProvider( routeResult.Provider.BaseURL, routeResult.Provider.APIKey, routeResult.Model.ModelName, ) // 判断是否流式 isStream := h.isStreamRequest(body, clientProtocol, nativePath) if isStream { h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) } else { h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) } } // handleNonStream 处理非流式请求 func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 发送请求 resp, err := h.client.Send(c.Request.Context(), *outSpec) if err != nil { h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 转换响应 interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType) if err != nil { h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 设置响应头 for k, v := range convertedResp.Headers { c.Header(k, v) } if c.GetHeader("Content-Type") == "" { c.Header("Content-Type", "application/json") } c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) go func() { _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) }() } // handleStream 处理流式请求 func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.writeConversionError(c, err, clientProtocol) return } // 创建流式转换器 streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol) if err != nil { h.writeConversionError(c, err, clientProtocol) return } // 发送流式请求 eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec) if err != nil { h.writeConversionError(c, err, clientProtocol) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") writer := bufio.NewWriter(c.Writer) for event := range eventChan { if event.Error != nil { h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) break } if event.Done { // flush 转换器 chunks := streamConverter.Flush() for _, chunk := range chunks { writer.Write(chunk) writer.Flush() } break } chunks := streamConverter.ProcessChunk(event.Data) for _, chunk := range chunks { writer.Write(chunk) writer.Flush() } } go func() { _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) }() } // isStreamRequest 判断是否流式请求 func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) if ifaceType != conversion.InterfaceTypeChat { return false } var req struct { Stream bool `json:"stream"` } if err := json.Unmarshal(body, &req); err != nil { return false } return req.Stream } // writeConversionError 写入转换错误 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { if convErr, ok := err.(*conversion.ConversionError); ok { body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) c.Data(statusCode, "application/json", body) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } // writeError 写入路由错误 func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) } // forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) { registry := h.engine.GetRegistry() adapter, err := registry.Get(clientProtocol) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) return } providers, err := h.providerService.List() if err != nil || len(providers) == 0 { h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL)) c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"}) return } p := providers[0] providerProtocol := p.Protocol if providerProtocol == "" { providerProtocol = "openai" } ifaceType := adapter.DetectInterfaceType(inSpec.URL) targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") var outSpec *conversion.HTTPRequestSpec if clientProtocol == providerProtocol { upstreamURL := p.BaseURL + inSpec.URL headers := adapter.BuildHeaders(targetProvider) if _, ok := headers["Content-Type"]; !ok { headers["Content-Type"] = "application/json" } outSpec = &conversion.HTTPRequestSpec{ URL: upstreamURL, Method: inSpec.Method, Headers: headers, Body: inSpec.Body, } } else { outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.writeConversionError(c, err, clientProtocol) return } } resp, err := h.client.Send(c.Request.Context(), *outSpec) if err != nil { h.writeConversionError(c, err, clientProtocol) return } convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType) if err != nil { h.writeConversionError(c, err, clientProtocol) return } for k, v := range convertedResp.Headers { c.Header(k, v) } if c.GetHeader("Content-Type") == "" { c.Header("Content-Type", "application/json") } c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) } // extractModelName 从 JSON body 中提取 model func extractModelName(body []byte) string { var req struct { Model string `json:"model"` } if err := json.Unmarshal(body, &req); err != nil { return "" } return req.Model } // extractHeaders 从 Gin context 提取请求头 func extractHeaders(c *gin.Context) map[string]string { headers := make(map[string]string) for k, vs := range c.Request.Header { if len(vs) > 0 { headers[k] = vs[0] } } return headers }