feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体 - 新增 service 层:models、providers、routing、stats 业务逻辑 - 新增 repository 层:models、providers、stats 数据访问 - 新增 pkg 工具包:errors、logger、validator - 新增中间件:CORS、logging、recovery、request ID - 新增数据库迁移:初始 schema 和索引 - 新增单元测试和集成测试 - 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage - 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
@@ -4,32 +4,36 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler() *OpenAIHandler {
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
// 解析请求
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
@@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(req.Model)
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
@@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
// 流错误,记录日志
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
// 流结束
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 写入事件数据
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型未找到",
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_not_found",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_disabled",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "provider_disabled",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors flattens a field→message validation-error map into
// a single human-readable string, prefixed with "请求验证失败: " and with the
// individual "field: message" parts separated by "; ".
//
// NOTE: Go map iteration order is randomized, so with more than one entry
// the ordering of the parts varies between runs (same as before this
// rewrite — callers must not depend on part order).
func formatValidationErrors(errors map[string]string) string {
	var b strings.Builder
	b.WriteString("请求验证失败: ")
	first := true
	for field, msg := range errors {
		if !first {
			b.WriteString("; ")
		}
		first = false
		fmt.Fprintf(&b, "%s: %s", field, msg)
	}
	return b.String()
}
|
||||
|
||||
Reference in New Issue
Block a user