1
0

feat: 实现分层架构,包含 domain、service、repository 和 pkg 层

- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
2026-04-16 00:47:20 +08:00
parent 915b004924
commit f18904af1e
77 changed files with 5727 additions and 1257 deletions

View File

@@ -7,30 +7,33 @@ import (
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/anthropic"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
"nex/backend/internal/service"
)
// AnthropicHandler Anthropic 协议处理器
type AnthropicHandler struct {
client *provider.Client
router *router.Router
client provider.ProviderClient
routingService service.RoutingService
statsService service.StatsService
}
// NewAnthropicHandler 创建 Anthropic 处理器
func NewAnthropicHandler() *AnthropicHandler {
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
return &AnthropicHandler{
client: provider.NewClient(),
router: router.NewRouter(),
client: client,
routingService: routingService,
statsService: statsService,
}
}
// HandleMessages 处理 Messages 请求
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
// 解析 Anthropic 请求
var req anthropic.MessagesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
@@ -43,7 +46,19 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 检查多模态内容
// 请求验证
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
errMsg := formatValidationErrors(validationErrors)
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "invalid_request_error",
Message: errMsg,
},
})
return
}
if err := h.checkMultimodalContent(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
@@ -55,7 +70,6 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 转换为 OpenAI 请求
openaiReq, err := anthropic.ConvertRequest(&req)
if err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
@@ -68,14 +82,12 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 路由到供应商
routeResult, err := h.router.Route(openaiReq.Model)
routeResult, err := h.routingService.Route(openaiReq.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, openaiReq, routeResult)
} else {
@@ -83,9 +95,7 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
}
}
// handleNonStreamRequest 处理非流式请求
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -98,7 +108,6 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
return
}
// 转换为 Anthropic 响应
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -111,18 +120,14 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
}()
// 返回响应
c.JSON(http.StatusOK, anthropicResp)
}
// handleStreamRequest 处理流式请求
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -135,24 +140,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 创建流式转换器
converter := anthropic.NewStreamConverter(
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
openaiReq.Model,
)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
fmt.Printf("流错误: %v\n", event.Error)
break
}
@@ -160,25 +160,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
break
}
// 解析 OpenAI 流块
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
if err != nil {
fmt.Printf("解析流块失败: %v\n", err)
continue
}
// 转换为 Anthropic 事件
anthropicEvents, err := converter.ConvertChunk(chunk)
if err != nil {
fmt.Printf("转换事件失败: %v\n", err)
continue
}
// 写入事件
for _, ae := range anthropicEvents {
eventStr, err := anthropic.SerializeEvent(ae)
if err != nil {
fmt.Printf("序列化事件失败: %v\n", err)
continue
}
writer.WriteString(eventStr)
@@ -186,13 +180,11 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
}
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
}()
}
// checkMultimodalContent 检查多模态内容
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
for _, msg := range req.Messages {
for _, block := range msg.Content {
@@ -204,40 +196,22 @@ func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest
return nil
}
// handleError 处理路由错误
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型未找到",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型已禁用",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "供应商已禁用",
},
})
default:
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "internal_error",
Message: "内部错误: " + err.Error(),
Message: appErr.Message,
},
})
return
}
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "internal_error",
Message: "内部错误: " + err.Error(),
},
})
}

View File

@@ -0,0 +1,290 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ============ Mock 实现 ============
type mockRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockStatsService struct {
err error
stats []domain.UsageStats
aggrResult []map[string]interface{}
}
func (m *mockStatsService) Record(providerID, modelName string) error {
return m.err
}
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return m.stats, nil
}
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
return m.aggrResult
}
type mockProviderService struct {
provider *domain.Provider
providers []domain.Provider
err error
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
}
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockProviderService) Delete(id string) error { return m.err }
type mockModelService struct {
model *domain.Model
models []domain.Model
err error
}
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockModelService) Delete(id string) error { return m.err }
type mockProviderClient struct {
resp *openai.ChatCompletionResponse
eventChan chan provider.StreamEvent
err error
}
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
return m.resp, m.err
}
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
return m.eventChan, m.err
}
// ============ OpenAI Handler 测试 ============
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
h := NewOpenAIHandler(nil, nil, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
h.HandleChatCompletions(c)
assert.Equal(t, 400, w.Code)
}
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
h := NewOpenAIHandler(nil, nil, nil)
// 缺少 model 字段
body, _ := json.Marshal(map[string]interface{}{
"messages": []map[string]string{{"role": "user", "content": "hi"}},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.HandleChatCompletions(c)
assert.Equal(t, 400, w.Code)
}
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
h := NewOpenAIHandler(nil, routingSvc, nil)
body, _ := json.Marshal(map[string]interface{}{
"model": "nonexistent",
"messages": []map[string]string{{"role": "user", "content": "hi"}},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.HandleChatCompletions(c)
assert.Equal(t, 404, w.Code)
}
// ============ Provider Handler 测试 ============
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
body, _ := json.Marshal(map[string]string{"id": "p1"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_ListProviders(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/providers", nil)
h.ListProviders(c)
assert.Equal(t, 200, w.Code)
var result []domain.Provider
json.Unmarshal(w.Body.Bytes(), &result)
assert.Len(t, result, 2)
}
func TestProviderHandler_GetProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil)
h.GetProvider(c)
assert.Equal(t, 200, w.Code)
}
// ============ Model Handler 测试 ============
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{"id": "m1"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
}
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ModelName: "gpt-4"},
{ID: "m2", ModelName: "gpt-3.5"},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/models", nil)
h.ListModels(c)
assert.Equal(t, 200, w.Code)
}
// ============ Stats Handler 测试 ============
func TestStatsHandler_GetStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats", nil)
h.GetStats(c)
assert.Equal(t, 200, w.Code)
}
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
h := NewStatsHandler(&mockStatsService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil)
h.GetStats(c)
assert.Equal(t, 400, w.Code)
}
func TestStatsHandler_AggregateStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
},
aggrResult: []map[string]interface{}{
{"provider_id": "p1", "request_count": 10},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)
h.AggregateStats(c)
assert.Equal(t, 200, w.Code)
}
// ============ writeError 测试 ============
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/", nil)
writeError(c, appErrors.ErrModelNotFound)
assert.Equal(t, 404, w.Code)
}
func TestFormatValidationErrors(t *testing.T) {
errs := map[string]string{
"model": "模型名称不能为空",
"messages": "消息列表不能为空",
}
result := formatValidationErrors(errs)
require.Contains(t, result, "请求验证失败")
require.Contains(t, result, "model")
require.Contains(t, result, "messages")
}

View File

@@ -0,0 +1,21 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,40 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logging 日志中间件
func Logging(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
requestID, _ := c.Get(RequestIDKey)
logger.Info("请求开始",
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("client_ip", c.ClientIP()),
zap.Any("request_id", requestID),
)
c.Next()
latency := time.Since(start)
statusCode := c.Writer.Status()
logger.Info("请求结束",
zap.Int("status", statusCode),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.Duration("latency", latency),
zap.Int("body_size", c.Writer.Size()),
zap.Any("request_id", requestID),
)
}
}

View File

@@ -0,0 +1,130 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestRequestID_GeneratesUUID(t *testing.T) {
r := gin.New()
r.Use(RequestID())
r.GET("/test", func(c *gin.Context) {
id, exists := c.Get(RequestIDKey)
assert.True(t, exists)
assert.NotEmpty(t, id)
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
}
func TestRequestID_UsesExistingHeader(t *testing.T) {
r := gin.New()
r.Use(RequestID())
r.GET("/test", func(c *gin.Context) {
id, _ := c.Get(RequestIDKey)
assert.Equal(t, "existing-id-123", id)
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Request-ID", "existing-id-123")
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID"))
}
func TestLogging(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Logging(logger))
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test?key=value", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
func TestRecovery_NoPanic(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Recovery(logger))
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
func TestRecovery_WithPanic(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Recovery(logger))
r.GET("/test", func(c *gin.Context) {
panic("test panic")
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 500, w.Code)
}
func TestCORS_NormalRequest(t *testing.T) {
r := gin.New()
r.Use(CORS())
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "POST")
}
func TestCORS_PreflightRequest(t *testing.T) {
r := gin.New()
r.Use(CORS())
r.OPTIONS("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("OPTIONS", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 204, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
}

View File

@@ -0,0 +1,29 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Recovery 错误恢复中间件
func Recovery(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
requestID, _ := c.Get(RequestIDKey)
logger.Error("panic recovered",
zap.Any("error", err),
zap.Any("request_id", requestID),
zap.String("path", c.Request.URL.Path),
zap.Stack("stack"),
)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"error": "内部错误",
})
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,21 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const RequestIDKey = "request_id"
// RequestID 请求 ID 中间件
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
c.Set(RequestIDKey, requestID)
c.Header("X-Request-ID", requestID)
c.Next()
}
}

View File

@@ -6,15 +6,20 @@ import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ModelHandler 模型管理处理器
type ModelHandler struct{}
type ModelHandler struct {
modelService service.ModelService
}
// NewModelHandler 创建模型处理器
func NewModelHandler() *ModelHandler {
return &ModelHandler{}
func NewModelHandler(modelService service.ModelService) *ModelHandler {
return &ModelHandler{modelService: modelService}
}
// CreateModel 创建模型
@@ -32,26 +37,21 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
return
}
// 创建模型对象
model := &config.Model{
model := &domain.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateModel(model)
err := h.modelService.Create(model)
if err != nil {
if err.Error() == "供应商不存在" {
if err == appErrors.ErrProviderNotFound {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -62,11 +62,9 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
func (h *ModelHandler) ListModels(c *gin.Context) {
providerID := c.Query("provider_id")
models, err := config.ListModels(providerID)
models, err := h.modelService.List(providerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -77,7 +75,7 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
func (h *ModelHandler) GetModel(c *gin.Context) {
id := c.Param("id")
model, err := config.GetModel(id)
model, err := h.modelService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -85,9 +83,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -106,8 +102,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
return
}
// 更新模型
err := config.UpdateModel(id, req)
err := h.modelService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -115,24 +110,19 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
})
return
}
if err.Error() == "供应商不存在" {
if err == appErrors.ErrProviderNotFound {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新模型失败: " + err.Error(),
})
writeError(c, err)
return
}
// 返回更新后的模型
model, err := config.GetModel(id)
model, err := h.modelService.Get(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -143,7 +133,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
func (h *ModelHandler) DeleteModel(c *gin.Context) {
id := c.Param("id")
err := config.DeleteModel(id)
err := h.modelService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -151,9 +141,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除模型失败: " + err.Error(),
})
writeError(c, err)
return
}

View File

@@ -4,32 +4,36 @@ import (
"bufio"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
"nex/backend/internal/service"
)
// OpenAIHandler OpenAI 协议处理器
type OpenAIHandler struct {
client *provider.Client
router *router.Router
client provider.ProviderClient
routingService service.RoutingService
statsService service.StatsService
}
// NewOpenAIHandler 创建 OpenAI 处理器
func NewOpenAIHandler() *OpenAIHandler {
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
return &OpenAIHandler{
client: provider.NewClient(),
router: router.NewRouter(),
client: client,
routingService: routingService,
statsService: statsService,
}
}
// HandleChatCompletions 处理 Chat Completions 请求
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
// 解析请求
var req openai.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
@@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
return
}
// 路由到供应商
routeResult, err := h.router.Route(req.Model)
// 请求验证
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: formatValidationErrors(validationErrors),
Type: "invalid_request_error",
},
})
return
}
routeResult, err := h.routingService.Route(req.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, &req, routeResult)
} else {
@@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
}
}
// handleNonStreamRequest 处理非流式请求
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
@@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
}()
// 返回响应
c.JSON(http.StatusOK, resp)
}
// handleStreamRequest 处理流式请求
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
@@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
// 流错误,记录日志
fmt.Printf("流错误: %v\n", event.Error)
break
}
if event.Done {
// 流结束
writer.WriteString("data: [DONE]\n\n")
writer.Flush()
break
}
// 写入事件数据
writer.WriteString("data: ")
writer.Write(event.Data)
writer.WriteString("\n\n")
writer.Flush()
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
}()
}
// handleError 处理路由错误
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型未找到",
Message: appErr.Message,
Type: "invalid_request_error",
Code: "model_not_found",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型已禁用",
Type: "invalid_request_error",
Code: "model_disabled",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "供应商已禁用",
Type: "invalid_request_error",
Code: "provider_disabled",
},
})
default:
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "内部错误: " + err.Error(),
Type: "internal_error",
Code: appErr.Code,
},
})
return
}
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "内部错误: " + err.Error(),
Type: "internal_error",
},
})
}
// formatValidationErrors 将验证错误 map 格式化为字符串
func formatValidationErrors(errors map[string]string) string {
parts := make([]string, 0, len(errors))
for field, msg := range errors {
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
}
return "请求验证失败: " + strings.Join(parts, "; ")
}

View File

@@ -2,19 +2,25 @@ package handler
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ProviderHandler 供应商管理处理器
type ProviderHandler struct{}
type ProviderHandler struct {
providerService service.ProviderService
}
// NewProviderHandler 创建供应商处理器
func NewProviderHandler() *ProviderHandler {
return &ProviderHandler{}
func NewProviderHandler(providerService service.ProviderService) *ProviderHandler {
return &ProviderHandler{providerService: providerService}
}
// CreateProvider 创建供应商
@@ -33,43 +39,34 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
return
}
// 创建供应商对象
provider := &config.Provider{
provider := &domain.Provider{
ID: req.ID,
Name: req.Name,
APIKey: req.APIKey,
BaseURL: req.BaseURL,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateProvider(provider)
err := h.providerService.Create(provider)
if err != nil {
// 检查是否是唯一约束错误ID 重复)
if err.Error() == "UNIQUE constraint failed: providers.id" {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 掩码 API Key 后返回
provider.MaskAPIKey()
c.JSON(http.StatusCreated, provider)
}
// ListProviders 列出所有供应商
func (h *ProviderHandler) ListProviders(c *gin.Context) {
providers, err := config.ListProviders()
providers, err := h.providerService.List()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -80,7 +77,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
func (h *ProviderHandler) GetProvider(c *gin.Context) {
id := c.Param("id")
provider, err := config.GetProvider(id, true) // 掩码 API Key
provider, err := h.providerService.Get(id, true)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -88,9 +85,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -109,8 +104,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
return
}
// 更新供应商
err := config.UpdateProvider(id, req)
err := h.providerService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -118,18 +112,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 返回更新后的供应商
provider, err := config.GetProvider(id, true)
provider, err := h.providerService.Get(id, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -140,8 +129,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
id := c.Param("id")
// 删除供应商(级联删除模型)
err := config.DeleteProvider(id)
err := h.providerService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -149,19 +137,23 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 删除关联的模型
models, _ := config.ListModels("")
for _, model := range models {
if model.ProviderID == id {
_ = config.DeleteModel(model.ID)
}
}
c.Status(http.StatusNoContent)
}
// writeError 统一错误响应处理
func writeError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, gin.H{
"error": appErr.Message,
"code": appErr.Code,
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
}

View File

@@ -6,20 +6,21 @@ import (
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/service"
)
// StatsHandler 统计处理器
type StatsHandler struct{}
type StatsHandler struct {
statsService service.StatsService
}
// NewStatsHandler 创建统计处理器
func NewStatsHandler() *StatsHandler {
return &StatsHandler{}
func NewStatsHandler(statsService service.StatsService) *StatsHandler {
return &StatsHandler{statsService: statsService}
}
// GetStats 查询统计
func (h *StatsHandler) GetStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
@@ -27,7 +28,6 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
@@ -50,8 +50,7 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
@@ -64,16 +63,14 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
// AggregateStats 聚合统计
func (h *StatsHandler) AggregateStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
groupBy := c.Query("group_by") // "provider", "model", "date"
groupBy := c.Query("group_by")
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
@@ -96,8 +93,7 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
@@ -105,80 +101,6 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
return
}
// 聚合
result := h.aggregate(stats, groupBy)
result := h.statsService.Aggregate(stats, groupBy)
c.JSON(http.StatusOK, result)
}
// aggregate 执行聚合
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
switch groupBy {
case "provider":
return h.aggregateByProvider(stats)
case "model":
return h.aggregateByModel(stats)
case "date":
return h.aggregateByDate(stats)
default:
// 默认按供应商聚合
return h.aggregateByProvider(stats)
}
}
// aggregateByProvider 按供应商聚合
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
aggregated[stat.ProviderID] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for providerID, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": providerID,
"request_count": count,
})
}
return result
}
// aggregateByModel 按模型聚合
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.ProviderID + "/" + stat.ModelName
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for key, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": key[:len(key)/2],
"model_name": key[len(key)/2+1:],
"request_count": count,
})
}
return result
}
// aggregateByDate 按日期聚合
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.Date.Format("2006-01-02")
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for date, count := range aggregated {
result = append(result, map[string]interface{}{
"date": date,
"request_count": count,
})
}
return result
}