package handler

import (
	"bufio"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
	"nex/backend/internal/domain"
	"nex/backend/internal/provider"
	"nex/backend/internal/service"
	appErrors "nex/backend/pkg/errors"
	pkglogger "nex/backend/pkg/logger"
	"nex/backend/pkg/modelid"
)

// ProxyHandler is the unified proxy handler. It accepts requests under a
// /{protocol}/... prefix, answers model-list and model-info queries locally,
// routes requests that carry a unified model ID to the matching provider
// (converting between protocols when needed), and forwards everything else
// as passthrough.
type ProxyHandler struct {
	engine          *conversion.ConversionEngine
	client          provider.ProviderClient
	routingService  service.RoutingService
	providerService service.ProviderService
	statsService    service.StatsService
	logger          *zap.Logger
}

// NewProxyHandler creates a unified proxy handler. The stored logger is
// derived from logger via pkglogger.WithModule with the module name
// "handler.proxy".
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
	return &ProxyHandler{
		engine:          engine,
		client:          client,
		routingService:  routingService,
		providerService: providerService,
		statsService:    statsService,
		logger:          pkglogger.WithModule(logger, "handler.proxy"),
	}
}

// HandleProxy handles one proxied request.
//
// The client protocol comes from the :protocol route parameter and the native
// path from :path (URL shape /{protocol}/v1/...). Dispatch order:
//   - Models / ModelInfo interfaces are answered locally.
//   - Requests with an empty body, or on interfaces that do not support model
//     extraction, are forwarded as passthrough.
//   - Otherwise the model name is extracted from the body. Malformed JSON is
//     rejected with 400 INVALID_JSON. A name that modelid.ParseUnifiedModelID
//     accepts (with non-empty provider and model parts) is routed to that
//     provider. Any other name falls back to passthrough.
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
	// Extract clientProtocol from the URL: /{protocol}/v1/...
	clientProtocol := c.Param("protocol")
	if clientProtocol == "" {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
		return
	}

	// Native path: /{path}
	path := c.Param("path")
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	nativePath := path
	requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)

	// Look up the client-side protocol adapter.
	registry := h.engine.GetRegistry()
	clientAdapter, err := registry.Get(clientProtocol)
	if err != nil {
		h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
		return
	}

	// Detect which interface the path addresses.
	ifaceType := clientAdapter.DetectInterfaceType(nativePath)

	// Models interface: aggregated locally.
	if ifaceType == conversion.InterfaceTypeModels {
		h.handleModelsList(c, clientAdapter)
		return
	}

	// ModelInfo interface: looked up locally.
	if ifaceType == conversion.InterfaceTypeModelInfo {
		unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
		if err != nil {
			h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
			return
		}
		h.handleModelInfo(c, unifiedID, clientAdapter)
		return
	}

	// Read the request body.
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
		return
	}

	// Build the inbound HTTPRequestSpec (URL keeps the client's raw query).
	inSpec := conversion.HTTPRequestSpec{
		URL:     requestPath,
		Method:  c.Request.Method,
		Headers: extractHeaders(c),
		Body:    body,
	}
	isStream := h.isStreamRequest(body, clientProtocol, nativePath)

	// Only interfaces the adapter explicitly supports get a model extracted.
	// Unknown interfaces are never guessed at.
	if len(body) == 0 || !supportsModelExtraction(ifaceType) {
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
	if err != nil {
		if isInvalidJSONError(err) {
			h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
			return
		}
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
	if err != nil {
		// Raw model names are passed through for compatibility: anything that
		// is not a unified model ID does not take part in routing.
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}
	if providerID == "" || modelName == "" {
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	// Route to the configured provider/model.
	routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
	if err != nil {
		h.writeRouteError(c, err)
		return
	}

	// Provider protocol; an empty value defaults to "openai".
	providerProtocol := routeResult.Provider.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	// Build the TargetProvider. ModelName is the upstream model name the
	// request is rewritten to:
	//   - same protocol (Smart Passthrough): the unified ID in the request
	//     body is replaced with ModelName;
	//   - cross protocol: ModelName is encoded into the fully converted body.
	targetProvider := conversion.NewTargetProvider(
		routeResult.Provider.BaseURL,
		routeResult.Provider.APIKey,
		routeResult.Model.ModelName, // upstream model name, used to rewrite the request
	)

	// Unified model ID, used to override the model field in the response.
	unifiedModelID := routeResult.Model.UnifiedModelID()

	if isStream {
		h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
	} else {
		h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
	}
}

// supportsModelExtraction reports whether a model name may be extracted from
// requests on ifaceType (chat, embeddings and rerank only).
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
	switch ifaceType {
	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
		return true
	default:
		return false
	}
}

// isInvalidJSONError reports whether err wraps a JSON syntax error or a JSON
// type-mismatch error from encoding/json.
func isInvalidJSONError(err error) bool {
	var syntaxErr *json.SyntaxError
	var typeErr *json.UnmarshalTypeError
	return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}

// appendRawQuery returns path with "?"+rawQuery appended, or path unchanged
// when rawQuery is empty.
func appendRawQuery(path, rawQuery string) string {
	if rawQuery == "" {
		return path
	}
	return path + "?" + rawQuery
}

// handleNonStream converts and sends a non-streaming request to the routed
// provider, then writes the converted response back to the client.
//
// Non-2xx upstream responses are relayed as-is. On success a stats record is
// written asynchronously.
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Send the request.
	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.Error(err))
		h.writeUpstreamUnavailable(c, err)
		return
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, *resp)
		return
	}

	// Convert the response. unifiedModelID is passed as the model override so
	// the client sees the unified ID in the model field.
	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
	if err != nil {
		h.logger.Error("转换响应失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	h.writeConvertedResponse(c, *convertedResp)

	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget stats recording must not block the request
	}()
}

// handleStream converts and sends a streaming request to the routed provider.
// It relays the upstream events to the client as text/event-stream through a
// stream converter.
//
// Non-2xx upstream responses are relayed as-is. Every converted batch is
// flushed all the way to the client so events arrive in real time. After the
// stream ends a stats record is written asynchronously.
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Send the streaming request.
	streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
	if err != nil {
		h.writeUpstreamUnavailable(c, err)
		return
	}
	if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
			StatusCode: streamResp.StatusCode,
			Headers:    streamResp.Headers,
			Body:       streamResp.Body,
		})
		return
	}

	// Create the stream converter. unifiedModelID is the model override so
	// streamed chunks carry the unified ID.
	streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)
	// emit writes a batch of converted chunks and then flushes the HTTP
	// response. bufio.Writer.Flush alone only hands bytes to the
	// ResponseWriter, whose own buffer would otherwise hold small SSE events
	// until it fills up or the handler returns.
	emit := func(chunks [][]byte) error {
		if len(chunks) == 0 {
			return nil
		}
		if err := h.writeStreamChunks(writer, chunks); err != nil {
			return err
		}
		c.Writer.Flush()
		return nil
	}

	flushed := false
	for event := range streamResp.Events {
		if event.Error != nil {
			h.logger.Error("流读取错误", zap.Error(event.Error))
			break
		}
		if event.Done {
			// Upstream signalled the end: drain the converter's buffered tail.
			if err := emit(streamConverter.Flush()); err != nil {
				h.logger.Warn("流式响应写回失败", zap.Error(err))
			}
			flushed = true
			break
		}
		if err := emit(streamConverter.ProcessChunk(event.Data)); err != nil {
			h.logger.Warn("流式响应写回失败", zap.Error(err))
			break
		}
	}
	// The stream ended without a Done event (read error, write failure or the
	// channel closed). Still drain the converter so buffered output is emitted.
	if !flushed {
		if err := emit(streamConverter.Flush()); err != nil {
			h.logger.Warn("流式响应写回失败", zap.Error(err))
		}
	}

	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget stats recording must not block the request
	}()
}

// writeStreamChunks writes every chunk to writer and then flushes writer once.
// It returns the first write or flush error.
//
// Flushing writer only pushes the bytes into whatever it wraps. When that is
// an http.ResponseWriter, callers must also flush the response (e.g. gin's
// c.Writer.Flush()) for the data to reach the client immediately.
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
	for _, chunk := range chunks {
		if _, err := writer.Write(chunk); err != nil {
			return err
		}
	}
	return writer.Flush()
}

// isStreamRequest reports whether the request asks for a streamed response.
// It returns true only for the chat interface, and only when the body decodes
// as JSON with "stream": true. Detection or decode failures yield false.
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
	ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
	if err != nil {
		return false
	}
	if ifaceType != conversion.InterfaceTypeChat {
		return false
	}

	var req struct {
		Stream bool `json:"stream"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return false
	}
	return req.Stream
}

// handleModelsList serves GET /v1/models locally. It lists all enabled models
// from the provider service and encodes them in the client's protocol via
// adapter.
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
	// Query every enabled model.
	models, err := h.providerService.ListEnabledModels()
	if err != nil {
		h.logger.Error("查询启用模型失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
		return
	}

	// Build the canonical model list, exposing each model under its unified ID.
	modelList := &canonical.CanonicalModelList{
		Models: make([]canonical.CanonicalModel, 0, len(models)),
	}
	for _, m := range models {
		modelList.Models = append(modelList.Models, canonical.CanonicalModel{
			ID:      m.UnifiedModelID(),
			Name:    m.ModelName,
			Created: m.CreatedAt.Unix(),
			OwnedBy: m.ProviderID,
		})
	}

	// Encode with the client's adapter.
	body, err := adapter.EncodeModelsResponse(modelList)
	if err != nil {
		h.logger.Error("编码 Models 响应失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
		return
	}

	c.Data(http.StatusOK, "application/json", body)
}

// handleModelInfo serves GET /v1/models/{unified_id} locally. It looks the
// model up by provider ID and model name and encodes it in the client's
// protocol.
//
// Any lookup error is reported as 404 MODEL_NOT_FOUND.
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
	// Parse the unified model ID.
	providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
	if err != nil {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
		return
	}

	// Look the model up.
	model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
	if err != nil {
		h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
		return
	}

	// Build the canonical model info.
	modelInfo := &canonical.CanonicalModelInfo{
		ID:      model.UnifiedModelID(),
		Name:    model.ModelName,
		Created: model.CreatedAt.Unix(),
		OwnedBy: model.ProviderID,
	}

	// Encode with the client's adapter.
	body, err := adapter.EncodeModelInfoResponse(modelInfo)
	if err != nil {
		h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
		return
	}

	c.Data(http.StatusOK, "application/json", body)
}
writeConversionError 写入网关层转换错误 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { var convErr *conversion.ConversionError if errors.As(err, &convErr) { statusCode, code, message := mapConversionError(convErr) h.writeProxyError(c, statusCode, code, message) return } h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error()) } func mapConversionError(err *conversion.ConversionError) (int, string, string) { switch err.Code { case conversion.ErrorCodeJSONParseError: if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest { return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误" } return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message case conversion.ErrorCodeInvalidInput, conversion.ErrorCodeMissingRequiredField, conversion.ErrorCodeProtocolConstraint: return http.StatusBadRequest, "INVALID_REQUEST", err.Message case conversion.ErrorCodeInterfaceNotSupported: return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message case conversion.ErrorCodeUnsupportedMultimodal: return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message default: return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message } } func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) { if appErr, ok := appErrors.AsAppError(err); ok { switch appErr.Code { case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code: h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message) case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code: h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message) default: h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message) } return } h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error()) } func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) { h.logger.Error("上游不可达", 
zap.Error(err)) h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达") } func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) { c.JSON(status, gin.H{ "error": message, "code": code, }) } func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) { for k, v := range resp.Headers { c.Header(k, v) } contentType := headerValue(resp.Headers, "Content-Type") if contentType == "" { contentType = "application/json" } c.Data(resp.StatusCode, contentType, resp.Body) } func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) { for k, v := range filterHopByHopHeaders(resp.Headers) { c.Header(k, v) } contentType := headerValue(resp.Headers, "Content-Type") c.Data(resp.StatusCode, contentType, resp.Body) } // forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) { registry := h.engine.GetRegistry() adapter, err := registry.Get(clientProtocol) if err != nil { h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol) return } providers, err := h.providerService.List() if err != nil || len(providers) == 0 { h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL)) h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。") return } p := providers[0] providerProtocol := p.Protocol if providerProtocol == "" { providerProtocol = "openai" } targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") var outSpec *conversion.HTTPRequestSpec if clientProtocol == providerProtocol { upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType) upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL)) headers := adapter.BuildHeaders(targetProvider) if _, ok := headers["Content-Type"]; !ok { 
headers["Content-Type"] = "application/json" } outSpec = &conversion.HTTPRequestSpec{ URL: joinBaseURL(p.BaseURL, upstreamPath), Method: inSpec.Method, Headers: headers, Body: inSpec.Body, } } else { outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.writeConversionError(c, err, clientProtocol) return } } if isStream { h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType) return } resp, err := h.client.Send(c.Request.Context(), *outSpec) if err != nil { h.writeUpstreamUnavailable(c, err) return } if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { h.writeUpstreamResponse(c, *resp) return } convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "") if err != nil { h.writeConversionError(c, err, clientProtocol) return } h.writeConvertedResponse(c, *convertedResp) } func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) { streamResp, err := h.client.SendStream(c.Request.Context(), outSpec) if err != nil { h.writeUpstreamUnavailable(c, err) return } if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices { h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{ StatusCode: streamResp.StatusCode, Headers: streamResp.Headers, Body: streamResp.Body, }) return } streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType) if err != nil { h.writeConversionError(c, err, clientProtocol) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") writer := bufio.NewWriter(c.Writer) flushed := false for event := range streamResp.Events { if event.Error != nil { h.logger.Error("透传流读取错误", zap.Error(event.Error)) break } if event.Done { chunks := 
streamConverter.Flush() if err := h.writeStreamChunks(writer, chunks); err != nil { h.logger.Warn("透传流式响应写回失败", zap.Error(err)) } flushed = true break } chunks := streamConverter.ProcessChunk(event.Data) if err := h.writeStreamChunks(writer, chunks); err != nil { h.logger.Warn("透传流式响应写回失败", zap.Error(err)) break } } if !flushed { chunks := streamConverter.Flush() if err := h.writeStreamChunks(writer, chunks); err != nil { h.logger.Warn("透传流式响应写回失败", zap.Error(err)) } } } func stripRawQuery(path string) string { pathOnly, _, _ := strings.Cut(path, "?") return pathOnly } func rawQueryFromPath(path string) string { _, rawQuery, found := strings.Cut(path, "?") if !found { return "" } return rawQuery } func joinBaseURL(baseURL, path string) string { return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/") } func headerValue(headers map[string]string, key string) string { for k, v := range headers { if strings.EqualFold(k, key) { return v } } return "" } func filterHopByHopHeaders(headers map[string]string) map[string]string { if len(headers) == 0 { return nil } hopByHop := map[string]struct{}{ "connection": {}, "transfer-encoding": {}, "keep-alive": {}, "proxy-authenticate": {}, "proxy-authorization": {}, "te": {}, "trailer": {}, "upgrade": {}, } filtered := make(map[string]string, len(headers)) for k, v := range headers { if _, skip := hopByHop[strings.ToLower(k)]; skip { continue } filtered[k] = v } return filtered } // extractHeaders 从 Gin context 提取请求头 func extractHeaders(c *gin.Context) map[string]string { headers := make(map[string]string) for k, vs := range c.Request.Header { if len(vs) > 0 { headers[k] = vs[0] } } return headers }