package handler

import (
	"bufio"
	"fmt"
	"net/http"
	"sort"
	"strings"

	"github.com/gin-gonic/gin"

	"nex/backend/internal/domain"
	"nex/backend/internal/protocol/openai"
	"nex/backend/internal/provider"
	"nex/backend/internal/service"
	appErrors "nex/backend/pkg/errors"
)

// OpenAIHandler serves OpenAI-protocol HTTP requests by routing them to an
// upstream provider and recording usage statistics.
type OpenAIHandler struct {
	client         provider.ProviderClient
	routingService service.RoutingService
	statsService   service.StatsService
}

// NewOpenAIHandler creates an OpenAIHandler wired to the given provider
// client, routing service and stats service.
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
	return &OpenAIHandler{
		client:         client,
		routingService: routingService,
		statsService:   statsService,
	}
}

// HandleChatCompletions handles a Chat Completions request.
//
// It binds the JSON body, validates it, resolves a route for the requested
// model, and then dispatches to the streaming or non-streaming path depending
// on req.Stream. Binding and validation failures are answered with HTTP 400
// and an "invalid_request_error" body; routing failures go through
// handleError.
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
	var req openai.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "无效的请求格式: " + err.Error(),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Request validation. Test the length rather than nil-ness so that an
	// empty (but non-nil) error map is not reported as a failure.
	if validationErrors := openai.ValidateRequest(&req); len(validationErrors) > 0 {
		c.JSON(http.StatusBadRequest, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: formatValidationErrors(validationErrors),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	routeResult, err := h.routingService.Route(req.Model)
	if err != nil {
		h.handleError(c, err)
		return
	}

	if req.Stream {
		h.handleStreamRequest(c, &req, routeResult)
		return
	}
	h.handleNonStreamRequest(c, &req, routeResult)
}

// handleNonStreamRequest forwards req to the routed provider and writes the
// provider response as JSON with HTTP 200. A provider failure is answered
// with HTTP 500 and an "api_error" body.
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
	resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}

	// Stats are recorded off the request path; the result is deliberately
	// ignored so a stats failure never affects the client response.
	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
	}()

	c.JSON(http.StatusOK, resp)
}

// handleStreamRequest forwards req to the routed provider in streaming mode
// and relays each event to the client as a server-sent event ("data: ...").
//
// If the stream cannot be opened, it answers with HTTP 500 and an "api_error"
// body. Otherwise it relays events until the channel is closed, an event
// carries an error, an event is marked Done (which emits "data: [DONE]"), or
// writing to the client fails.
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
	eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)

	// emit writes one SSE frame and pushes it all the way to the client.
	emit := func(payload []byte) error {
		// bufio.Writer errors are sticky: once a write fails, every later
		// write and Flush returns that error, so checking Flush is enough.
		_, _ = writer.WriteString("data: ")
		_, _ = writer.Write(payload)
		_, _ = writer.WriteString("\n\n")
		if err := writer.Flush(); err != nil {
			return err
		}
		// Flushing the bufio.Writer only hands the bytes to the response
		// writer; the HTTP layer must be flushed as well, otherwise frames
		// are held back until its buffer fills or the handler returns.
		c.Writer.Flush()
		return nil
	}

	for event := range eventChan {
		if event.Error != nil {
			// Headers are already committed, so the failure cannot be turned
			// into an HTTP error status here; the stream simply ends.
			break
		}
		if event.Done {
			// Final frame: nothing is left to do if this write fails.
			_ = emit([]byte("[DONE]"))
			break
		}
		if err := emit(event.Data); err != nil {
			// The client is no longer reachable; stop relaying.
			break
		}
	}

	// Stats are recorded off the request path; the result is deliberately
	// ignored so a stats failure never affects the client response.
	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
	}()
}

// handleError writes err to the client as an OpenAI-style error response.
// When appErrors.AsAppError recognises err, its HTTP status, message and code
// are used; otherwise the response is HTTP 500 with type "internal_error".
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
	if appErr, ok := appErrors.AsAppError(err); ok {
		c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: appErr.Message,
				Type:    "invalid_request_error",
				Code:    appErr.Code,
			},
		})
		return
	}

	c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
		Error: openai.ErrorDetail{
			Message: "内部错误: " + err.Error(),
			Type:    "internal_error",
		},
	})
}

// formatValidationErrors formats a map of field name to validation message
// as a single string of "field: message" entries joined by "; ".
//
// Entries are ordered by field name so the message is deterministic; plain
// map iteration order is randomised.
func formatValidationErrors(errors map[string]string) string {
	fields := make([]string, 0, len(errors))
	for field := range errors {
		fields = append(fields, field)
	}
	sort.Strings(fields)

	parts := make([]string, 0, len(fields))
	for _, field := range fields {
		parts = append(parts, fmt.Sprintf("%s: %s", field, errors[field]))
	}
	return "请求验证失败: " + strings.Join(parts, "; ")
}