package handler

import (
	"bufio"
	"errors"
	"fmt"
	"net/http"

	"github.com/gin-gonic/gin"

	"nex/backend/internal/config"
	"nex/backend/internal/protocol/openai"
	"nex/backend/internal/provider"
	"nex/backend/internal/router"
)

// OpenAIHandler serves OpenAI-protocol endpoints. It resolves the requested
// model to an upstream provider with router and forwards the request to that
// provider with client.
type OpenAIHandler struct {
	client *provider.Client
	router *router.Router
}

// NewOpenAIHandler creates an OpenAIHandler backed by a new provider.Client
// and a new router.Router.
func NewOpenAIHandler() *OpenAIHandler {
	return &OpenAIHandler{
		client: provider.NewClient(),
		router: router.NewRouter(),
	}
}

// HandleChatCompletions handles a Chat Completions request.
//
// It binds the JSON body into an openai.ChatCompletionRequest (responding 400
// with type "invalid_request_error" if binding fails), routes req.Model to a
// provider (routing failures are rendered by handleError), and then dispatches
// to the streaming or non-streaming path depending on req.Stream.
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
	var req openai.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "无效的请求格式: " + err.Error(),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	routeResult, err := h.router.Route(req.Model)
	if err != nil {
		h.handleError(c, err)
		return
	}

	if req.Stream {
		h.handleStreamRequest(c, &req, routeResult)
		return
	}
	h.handleNonStreamRequest(c, &req, routeResult)
}

// handleNonStreamRequest forwards a non-streaming request to the routed
// provider and relays the provider's response as JSON with status 200.
// If the provider call fails, it responds 500 with type "api_error".
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}

	h.recordUsage(routeResult, req.Model)

	c.JSON(http.StatusOK, resp)
}

// handleStreamRequest forwards a streaming request to the routed provider and
// relays each upstream event to the client as a Server-Sent Events frame of
// the form "data: <payload>\n\n". When the provider signals completion, it
// writes "data: [DONE]\n\n".
//
// If the upstream stream cannot be started, it responds 500 with type
// "api_error". An error event from the stream is logged and ends the stream;
// no error frame is sent to the client. A failed write to the client also
// ends forwarding.
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}

	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)

	// flush pushes the buffered frame out of bufio and then through the
	// http.ResponseWriter's own buffer. Without the http.Flusher call, net/http
	// may hold small frames back until its buffer fills or the handler
	// returns, which defeats streaming. bufio errors are sticky, so the error
	// returned here also covers any failed WriteString/Write before it.
	flush := func() error {
		if err := writer.Flush(); err != nil {
			return err
		}
		c.Writer.Flush()
		return nil
	}

	// NOTE(review): breaking out of this loop stops receiving from eventChan.
	// This assumes the producer in provider.Client stops sending once
	// c.Request.Context() is cancelled (which happens when this handler
	// returns); confirm, otherwise the producer could block on send.
	for event := range eventChan {
		if event.Error != nil {
			fmt.Printf("流错误: %v\n", event.Error)
			break
		}

		if event.Done {
			writer.WriteString("data: [DONE]\n\n")
		} else {
			writer.WriteString("data: ")
			writer.Write(event.Data)
			writer.WriteString("\n\n")
		}

		if err := flush(); err != nil {
			// The frame could not be delivered (e.g. the client disconnected),
			// so there is no point forwarding further events.
			fmt.Printf("写入流失败: %v\n", err)
			break
		}

		if event.Done {
			break
		}
	}

	h.recordUsage(routeResult, req.Model)
}

// recordUsage records one request for the routed provider and the given model
// via config.RecordRequest. It runs in a background goroutine so bookkeeping
// does not delay the response. Failures are logged and otherwise ignored.
func (h *OpenAIHandler) recordUsage(routeResult *router.RouteResult, model string) {
	// Capture the ID now so the goroutine does not hold on to routeResult.
	providerID := routeResult.Provider.ID
	go func() {
		if err := config.RecordRequest(providerID, model); err != nil {
			fmt.Printf("记录统计失败: %v\n", err)
		}
	}()
}

// handleError renders a routing error as an OpenAI-style JSON error response.
//
// The router's sentinel errors (ErrModelNotFound, ErrModelDisabled,
// ErrProviderDisabled) map to 404 with type "invalid_request_error" and a
// specific code. They are matched with errors.Is, so wrapped errors are
// recognized too. Any other error becomes a 500 with type "internal_error"
// that includes err's text.
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
	switch {
	case errors.Is(err, router.ErrModelNotFound):
		c.JSON(http.StatusNotFound, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "模型未找到",
				Type:    "invalid_request_error",
				Code:    "model_not_found",
			},
		})
	case errors.Is(err, router.ErrModelDisabled):
		c.JSON(http.StatusNotFound, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "模型已禁用",
				Type:    "invalid_request_error",
				Code:    "model_disabled",
			},
		})
	case errors.Is(err, router.ErrProviderDisabled):
		c.JSON(http.StatusNotFound, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商已禁用",
				Type:    "invalid_request_error",
				Code:    "provider_disabled",
			},
		})
	default:
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "内部错误: " + err.Error(),
				Type:    "internal_error",
			},
		})
	}
}