1
0
Files
nex/backend/internal/handler/openai_handler.go
lanyuanxiaoyao f18904af1e feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
2026-04-16 00:47:20 +08:00

158 lines
4.0 KiB
Go

package handler
import (
"bufio"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/service"
)
// OpenAIHandler serves OpenAI-protocol chat completion requests. It resolves
// the requested model to a provider through routingService, forwards the call
// through client, and records a stats entry through statsService.
type OpenAIHandler struct {
	// client sends streaming and non-streaming requests to the upstream provider.
	client provider.ProviderClient
	// routingService maps a requested model name to its provider route.
	routingService service.RoutingService
	// statsService records a (provider ID, model) entry after proxied calls.
	statsService service.StatsService
}
// NewOpenAIHandler builds an OpenAIHandler from its upstream client and the
// routing and stats services it depends on. The dependencies are stored as
// given; no nil checks are performed.
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
	h := new(OpenAIHandler)
	h.client = client
	h.routingService = routingService
	h.statsService = statsService
	return h
}
// HandleChatCompletions handles an OpenAI-format Chat Completions request.
//
// The request proceeds in four steps:
//  1. Decode the JSON body.
//  2. Validate it with openai.ValidateRequest.
//  3. Resolve the provider for req.Model through the routing service.
//  4. Proxy the call in streaming or non-streaming mode, as selected by req.Stream.
//
// Decode and validation failures are answered with 400 "invalid_request_error".
// Routing failures are delegated to handleError.
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
	// badRequest writes a 400 response in the OpenAI error envelope.
	badRequest := func(message string) {
		c.JSON(http.StatusBadRequest, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: message,
				Type:    "invalid_request_error",
			},
		})
	}

	var req openai.ChatCompletionRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		badRequest("无效的请求格式: " + bindErr.Error())
		return
	}
	if problems := openai.ValidateRequest(&req); problems != nil {
		badRequest(formatValidationErrors(problems))
		return
	}

	route, routeErr := h.routingService.Route(req.Model)
	if routeErr != nil {
		h.handleError(c, routeErr)
		return
	}

	if !req.Stream {
		h.handleNonStreamRequest(c, &req, route)
		return
	}
	h.handleStreamRequest(c, &req, route)
}
// handleNonStreamRequest forwards req to the routed provider and relays the
// provider's response to the client as a single 200 JSON body.
//
// If the upstream call fails, the client receives 500 "api_error" carrying the
// upstream error text. On success, a stats entry for (provider ID, model) is
// recorded in a background goroutine. That goroutine's result is discarded:
// stats are best-effort and must not affect the response.
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
	target := routeResult.Provider
	resp, err := h.client.SendRequest(c.Request.Context(), req, target.APIKey, target.BaseURL)
	if err != nil {
		failure := openai.ErrorDetail{Message: "供应商请求失败: " + err.Error(), Type: "api_error"}
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{Error: failure})
		return
	}

	providerID, model := target.ID, req.Model
	go func() {
		_ = h.statsService.Record(providerID, model)
	}()

	c.JSON(http.StatusOK, resp)
}
// handleStreamRequest proxies a streaming chat completion to the routed
// provider and relays it to the client as Server-Sent Events.
//
// Each upstream event is written as a "data: <payload>\n\n" frame and pushed
// to the network immediately, so tokens reach the client as they arrive. A
// Done event produces the terminating "data: [DONE]\n\n" frame.
//
// If the provider reports an error mid-stream, an OpenAI-style error frame is
// written and relaying stops. The HTTP status cannot be changed at that point,
// so the error must travel in-band. If opening the stream fails, the client
// receives 500 "api_error" instead.
//
// After the loop ends, a stats entry for (provider ID, model) is recorded in a
// background goroutine. Its result is discarded.
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
	// streamErrorPayload is the in-band error frame body. It is a fixed JSON
	// literal, so it is always valid JSON and does not expose upstream details.
	const streamErrorPayload = `{"error":{"message":"供应商请求失败","type":"api_error"}}`

	eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)
	// emit writes one complete SSE frame and sends it to the client now.
	// bufio.Writer.Flush only moves bytes into c.Writer. net/http keeps its own
	// response buffer, so c.Writer.Flush (http.Flusher) is required; without it,
	// frames would sit in that buffer instead of streaming per event.
	// Write errors are ignored (best effort; the client may have disconnected).
	emit := func(payload []byte) {
		_, _ = writer.WriteString("data: ")
		_, _ = writer.Write(payload)
		_, _ = writer.WriteString("\n\n")
		_ = writer.Flush()
		c.Writer.Flush()
	}

	for event := range eventChan {
		if event.Error != nil {
			// Previously this broke out silently, leaving the client with a
			// truncated (possibly empty) 200 stream and no error indication.
			emit([]byte(streamErrorPayload))
			break
		}
		if event.Done {
			emit([]byte("[DONE]"))
			break
		}
		emit(event.Data)
	}

	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
	}()
}
// handleError writes err to the client in the OpenAI error envelope.
//
// Errors recognized by appErrors.AsAppError keep their own HTTP status,
// message and code. Their error type follows the status class:
//   - non-5xx statuses are reported as "invalid_request_error";
//   - 5xx statuses are reported as "internal_error", matching the fallback
//     branch below. Previously, server-side AppErrors were mislabeled as
//     client errors.
//
// Any other error becomes a 500 "internal_error" that includes err's text.
// NOTE(review): that fallback exposes internal error details to clients;
// consider replacing it with a generic message plus server-side logging.
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
	if appErr, ok := appErrors.AsAppError(err); ok {
		errType := "invalid_request_error"
		if appErr.HTTPStatus >= http.StatusInternalServerError {
			errType = "internal_error"
		}
		c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: appErr.Message,
				Type:    errType,
				Code:    appErr.Code,
			},
		})
		return
	}

	c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
		Error: openai.ErrorDetail{
			Message: "内部错误: " + err.Error(),
			Type:    "internal_error",
		},
	})
}
// formatValidationErrors renders a field→message validation map as a single
// message of the form "请求验证失败: field1: msg1; field2: msg2".
//
// Go randomizes map iteration order, so the rendered "field: msg" parts are
// sorted lexicographically. The same set of failures therefore always yields
// the same message. A nil or empty map yields just the prefix.
func formatValidationErrors(fieldErrs map[string]string) string {
	parts := make([]string, 0, len(fieldErrs))
	for field, msg := range fieldErrs {
		parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
	}
	// Validation maps hold only a handful of fields, so an in-place insertion
	// sort is enough here.
	for i := 1; i < len(parts); i++ {
		for j := i; j > 0 && parts[j] < parts[j-1]; j-- {
			parts[j], parts[j-1] = parts[j-1], parts[j]
		}
	}
	return "请求验证失败: " + strings.Join(parts, "; ")
}