package handler import ( "bufio" "encoding/json" "io" "net/http" "strings" "github.com/gin-gonic/gin" "go.uber.org/zap" "nex/backend/internal/conversion" "nex/backend/internal/conversion/canonical" "nex/backend/internal/domain" "nex/backend/internal/provider" "nex/backend/internal/service" "nex/backend/pkg/modelid" ) // ProxyHandler 统一代理处理器 type ProxyHandler struct { engine *conversion.ConversionEngine client provider.ProviderClient routingService service.RoutingService providerService service.ProviderService statsService service.StatsService logger *zap.Logger } // NewProxyHandler 创建统一代理处理器 func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler { return &ProxyHandler{ engine: engine, client: client, routingService: routingService, providerService: providerService, statsService: statsService, logger: zap.L(), } } // HandleProxy 处理代理请求 func (h *ProxyHandler) HandleProxy(c *gin.Context) { // 从 URL 提取 clientProtocol: /{protocol}/v1/... clientProtocol := c.Param("protocol") if clientProtocol == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"}) return } // 原始路径: /v1/{path} path := c.Param("path") if strings.HasPrefix(path, "/") { path = path[1:] } nativePath := "/v1/" + path // 获取 client adapter registry := h.engine.GetRegistry() clientAdapter, err := registry.Get(clientProtocol) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) return } // 检测接口类型 ifaceType := clientAdapter.DetectInterfaceType(nativePath) // 处理 Models 接口:本地聚合 if ifaceType == conversion.InterfaceTypeModels { h.handleModelsList(c, clientAdapter) return } // 处理 ModelInfo 接口:本地查询 if ifaceType == conversion.InterfaceTypeModelInfo { unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"}) return } h.handleModelInfo(c, unifiedID, clientAdapter) return } // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) return } // 解析统一模型 ID(使用 adapter.ExtractModelName) var providerID, modelName string if len(body) > 0 { unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType) if err == nil && unifiedID != "" { pid, mn, err := modelid.ParseUnifiedModelID(unifiedID) if err == nil { providerID = pid modelName = mn } } } // 构建输入 HTTPRequestSpec inSpec := conversion.HTTPRequestSpec{ URL: nativePath, Method: c.Request.Method, Headers: extractHeaders(c), Body: body, } // 路由 routeResult, err := h.routingService.RouteByModelName(providerID, modelName) if err != nil { // GET 请求或无法提取 model 时,直接转发到上游 if len(body) == 0 || modelName == "" { h.forwardPassthrough(c, inSpec, clientProtocol) return } h.writeError(c, err, clientProtocol) return } // 确定 providerProtocol providerProtocol := routeResult.Provider.Protocol if providerProtocol == "" { providerProtocol = "openai" } // 构建 TargetProvider // 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体 // 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名) // 跨协议:全量转换时 ModelName 会被编码到请求体中 targetProvider := conversion.NewTargetProvider( routeResult.Provider.BaseURL, routeResult.Provider.APIKey, routeResult.Model.ModelName, // 上游模型名,用于请求改写 ) // 判断是否流式 isStream := h.isStreamRequest(body, clientProtocol, nativePath) // 计算统一模型 ID(用于响应覆写) unifiedModelID := routeResult.Model.UnifiedModelID() if isStream { h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType) } else { h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType) } } // handleNonStream 处理非流式请求 func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 发送请求 resp, err := h.client.Send(c.Request.Context(), *outSpec) if err != nil { h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 转换响应,传入 modelOverride(跨协议场景覆写 model 字段) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) if err != nil { h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) return } // 设置响应头 for k, v := range convertedResp.Headers { c.Header(k, v) } if c.GetHeader("Content-Type") == "" { c.Header("Content-Type", "application/json") } c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) go func() { _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) }() } // handleStream 处理流式请求 func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.writeConversionError(c, err, clientProtocol) return } // 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段) streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType) if err != nil { h.writeConversionError(c, err, clientProtocol) return } // 发送流式请求 eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec) if err != nil { h.writeConversionError(c, err, clientProtocol) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") writer := bufio.NewWriter(c.Writer) for event := range eventChan { if event.Error != nil { h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) break } if event.Done { // flush 转换器 chunks := streamConverter.Flush() for _, chunk := range chunks { writer.Write(chunk) writer.Flush() } break } chunks := streamConverter.ProcessChunk(event.Data) for _, chunk := range chunks { writer.Write(chunk) writer.Flush() } } go func() { _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) }() } // isStreamRequest 判断是否流式请求 func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) if ifaceType != conversion.InterfaceTypeChat { return false } var req struct { Stream bool `json:"stream"` } if err := json.Unmarshal(body, &req); err != nil { return false } return req.Stream } // handleModelsList 处理 GET /v1/models 本地聚合 func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) { // 从数据库查询所有启用的模型 models, err := h.providerService.ListEnabledModels() if err != nil { h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) return } // 构建 CanonicalModelList modelList := &canonical.CanonicalModelList{ Models: make([]canonical.CanonicalModel, 0, len(models)), } for _, m := range models { modelList.Models = append(modelList.Models, canonical.CanonicalModel{ ID: m.UnifiedModelID(), Name: m.ModelName, Created: m.CreatedAt.Unix(), OwnedBy: m.ProviderID, }) } // 使用 adapter 编码返回 body, err := adapter.EncodeModelsResponse(modelList) if err != nil { h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) return } c.Data(http.StatusOK, "application/json", body) } // handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询 func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) { // 解析统一模型 ID providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": "无效的统一模型 ID 格式", "code": "INVALID_MODEL_ID", }) return } // 从数据库查询模型 model, err := h.providerService.GetModelByProviderAndName(providerID, modelName) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"}) return } // 构建 CanonicalModelInfo modelInfo := &canonical.CanonicalModelInfo{ ID: model.UnifiedModelID(), Name: model.ModelName, Created: model.CreatedAt.Unix(), OwnedBy: model.ProviderID, } // 使用 adapter 编码返回 body, err := adapter.EncodeModelInfoResponse(modelInfo) if err != nil { h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error())) c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) return } c.Data(http.StatusOK, "application/json", body) } // writeConversionError 写入转换错误 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { if convErr, ok := err.(*conversion.ConversionError); ok { body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) c.Data(statusCode, "application/json", body) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } // writeError 写入路由错误 func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) } // forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) { registry := h.engine.GetRegistry() adapter, err := registry.Get(clientProtocol) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) return } providers, err := h.providerService.List() if err != nil || len(providers) == 0 { h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL)) c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"}) return } p := providers[0] providerProtocol := p.Protocol if providerProtocol == "" { providerProtocol = "openai" } ifaceType := adapter.DetectInterfaceType(inSpec.URL) targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") var outSpec *conversion.HTTPRequestSpec if clientProtocol == providerProtocol { upstreamURL := p.BaseURL + inSpec.URL headers := adapter.BuildHeaders(targetProvider) if _, ok := headers["Content-Type"]; !ok { headers["Content-Type"] = "application/json" } outSpec = &conversion.HTTPRequestSpec{ URL: upstreamURL, Method: inSpec.Method, Headers: headers, Body: inSpec.Body, } } else { outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { h.writeConversionError(c, err, clientProtocol) return } } resp, err := h.client.Send(c.Request.Context(), *outSpec) if err != nil { h.writeConversionError(c, err, clientProtocol) return } convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "") if err != nil { h.writeConversionError(c, err, clientProtocol) return } for k, v := range convertedResp.Headers { c.Header(k, v) } if c.GetHeader("Content-Type") == "" { c.Header("Content-Type", "application/json") } c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) } // extractHeaders 从 Gin context 提取请求头 func extractHeaders(c *gin.Context) map[string]string { headers := make(map[string]string) for k, vs := range c.Request.Header { if len(vs) > 0 { headers[k] = vs[0] } } return headers }