feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体 - 新增 service 层:models、providers、routing、stats 业务逻辑 - 新增 repository 层:models、providers、stats 数据访问 - 新增 pkg 工具包:errors、logger、validator - 新增中间件:CORS、logging、recovery、request ID - 新增数据库迁移:初始 schema 和索引 - 新增单元测试和集成测试 - 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage - 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
@@ -7,30 +7,33 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler() *AnthropicHandler {
|
||||
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
// 解析 Anthropic 请求
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
@@ -43,7 +46,19 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查多模态内容
|
||||
// 请求验证
|
||||
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
|
||||
errMsg := formatValidationErrors(validationErrors)
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: errMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
@@ -55,7 +70,6 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 OpenAI 请求
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
@@ -68,14 +82,12 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(openaiReq.Model)
|
||||
routeResult, err := h.routingService.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
@@ -83,9 +95,7 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -98,7 +108,6 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 响应
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -111,18 +120,14 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -135,24 +140,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 创建流式转换器
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -160,25 +160,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 OpenAI 流块
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
fmt.Printf("解析流块失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 事件
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
fmt.Printf("转换事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入事件
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
fmt.Printf("序列化事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
@@ -186,13 +180,11 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
}
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// checkMultimodalContent 检查多模态内容
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
@@ -204,40 +196,22 @@ func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型未找到",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型已禁用",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "供应商已禁用",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Message: appErr.Message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
290
backend/internal/handler/handler_test.go
Normal file
290
backend/internal/handler/handler_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
resp *openai.ChatCompletionResponse
|
||||
eventChan chan provider.StreamEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
return m.resp, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
|
||||
return m.eventChan, m.err
|
||||
}
|
||||
|
||||
// ============ OpenAI Handler 测试 ============
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
|
||||
// 缺少 model 字段
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
|
||||
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
|
||||
h := NewOpenAIHandler(nil, routingSvc, nil)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/providers", nil)
|
||||
|
||||
h.ListProviders(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil)
|
||||
|
||||
h.GetProvider(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ModelName: "gpt-4"},
|
||||
{ID: "m2", ModelName: "gpt-3.5"},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/models", nil)
|
||||
|
||||
h.ListModels(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats", nil)
|
||||
|
||||
h.GetStats(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil)
|
||||
|
||||
h.GetStats(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)
|
||||
|
||||
h.AggregateStats(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
writeError(c, appErrors.ErrModelNotFound)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestFormatValidationErrors(t *testing.T) {
|
||||
errs := map[string]string{
|
||||
"model": "模型名称不能为空",
|
||||
"messages": "消息列表不能为空",
|
||||
}
|
||||
result := formatValidationErrors(errs)
|
||||
require.Contains(t, result, "请求验证失败")
|
||||
require.Contains(t, result, "model")
|
||||
require.Contains(t, result, "messages")
|
||||
}
|
||||
21
backend/internal/handler/middleware/cors.go
Normal file
21
backend/internal/handler/middleware/cors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
40
backend/internal/handler/middleware/logging.go
Normal file
40
backend/internal/handler/middleware/logging.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
130
backend/internal/handler/middleware/middleware_test.go
Normal file
130
backend/internal/handler/middleware/middleware_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestRequestID_GeneratesUUID(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
id, exists := c.Get(RequestIDKey)
|
||||
assert.True(t, exists)
|
||||
assert.NotEmpty(t, id)
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
|
||||
}
|
||||
|
||||
func TestRequestID_UsesExistingHeader(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
id, _ := c.Get(RequestIDKey)
|
||||
assert.Equal(t, "existing-id-123", id)
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-ID", "existing-id-123")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID"))
|
||||
}
|
||||
|
||||
func TestLogging(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logging(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test?key=value", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestRecovery_NoPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestRecovery_WithPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestCORS_NormalRequest(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "POST")
|
||||
}
|
||||
|
||||
func TestCORS_PreflightRequest(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.OPTIONS("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("OPTIONS", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 204, w.Code)
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
29
backend/internal/handler/middleware/recovery.go
Normal file
29
backend/internal/handler/middleware/recovery.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Recovery 错误恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Error("panic recovered",
|
||||
zap.Any("error", err),
|
||||
zap.Any("request_id", requestID),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Stack("stack"),
|
||||
)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "内部错误",
|
||||
})
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
21
backend/internal/handler/middleware/request_id.go
Normal file
21
backend/internal/handler/middleware/request_id.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const RequestIDKey = "request_id"
|
||||
|
||||
// RequestID 请求 ID 中间件
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
c.Set(RequestIDKey, requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,20 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
type ModelHandler struct{}
|
||||
type ModelHandler struct {
|
||||
modelService service.ModelService
|
||||
}
|
||||
|
||||
// NewModelHandler 创建模型处理器
|
||||
func NewModelHandler() *ModelHandler {
|
||||
return &ModelHandler{}
|
||||
func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{modelService: modelService}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
@@ -32,26 +37,21 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建模型对象
|
||||
model := &config.Model{
|
||||
model := &domain.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateModel(model)
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err.Error() == "供应商不存在" {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,11 +62,9 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
providerID := c.Query("provider_id")
|
||||
|
||||
models, err := config.ListModels(providerID)
|
||||
models, err := h.modelService.List(providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +75,7 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
model, err := config.GetModel(id)
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -85,9 +83,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,8 +102,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新模型
|
||||
err := config.UpdateModel(id, req)
|
||||
err := h.modelService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -115,24 +110,19 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err.Error() == "供应商不存在" {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的模型
|
||||
model, err := config.GetModel(id)
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,7 +133,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := config.DeleteModel(id)
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -151,9 +141,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,32 +4,36 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler() *OpenAIHandler {
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
// 解析请求
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
@@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(req.Model)
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
@@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
// 流错误,记录日志
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
// 流结束
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 写入事件数据
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型未找到",
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_not_found",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_disabled",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "provider_disabled",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors 将验证错误 map 格式化为字符串
|
||||
func formatValidationErrors(errors map[string]string) string {
|
||||
parts := make([]string, 0, len(errors))
|
||||
for field, msg := range errors {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
@@ -2,19 +2,25 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
type ProviderHandler struct{}
|
||||
type ProviderHandler struct {
|
||||
providerService service.ProviderService
|
||||
}
|
||||
|
||||
// NewProviderHandler 创建供应商处理器
|
||||
func NewProviderHandler() *ProviderHandler {
|
||||
return &ProviderHandler{}
|
||||
func NewProviderHandler(providerService service.ProviderService) *ProviderHandler {
|
||||
return &ProviderHandler{providerService: providerService}
|
||||
}
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
@@ -33,43 +39,34 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建供应商对象
|
||||
provider := &config.Provider{
|
||||
provider := &domain.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateProvider(provider)
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
// 检查是否是唯一约束错误(ID 重复)
|
||||
if err.Error() == "UNIQUE constraint failed: providers.id" {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 掩码 API Key 后返回
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
providers, err := config.ListProviders()
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -80,7 +77,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := config.GetProvider(id, true) // 掩码 API Key
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -88,9 +85,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,8 +104,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新供应商
|
||||
err := config.UpdateProvider(id, req)
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -118,18 +112,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的供应商
|
||||
provider, err := config.GetProvider(id, true)
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,8 +129,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 删除供应商(级联删除模型)
|
||||
err := config.DeleteProvider(id)
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -149,19 +137,23 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 删除关联的模型
|
||||
models, _ := config.ListModels("")
|
||||
for _, model := range models {
|
||||
if model.ProviderID == id {
|
||||
_ = config.DeleteModel(model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// writeError 统一错误响应处理
|
||||
func writeError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, gin.H{
|
||||
"error": appErr.Message,
|
||||
"code": appErr.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,20 +6,21 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
type StatsHandler struct{}
|
||||
type StatsHandler struct {
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewStatsHandler 创建统计处理器
|
||||
func NewStatsHandler() *StatsHandler {
|
||||
return &StatsHandler{}
|
||||
func NewStatsHandler(statsService service.StatsService) *StatsHandler {
|
||||
return &StatsHandler{statsService: statsService}
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
@@ -27,7 +28,6 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
@@ -50,8 +50,7 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
@@ -64,16 +63,14 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
|
||||
// AggregateStats 聚合统计
|
||||
func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
groupBy := c.Query("group_by") // "provider", "model", "date"
|
||||
groupBy := c.Query("group_by")
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
@@ -96,8 +93,7 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
@@ -105,80 +101,6 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合
|
||||
result := h.aggregate(stats, groupBy)
|
||||
|
||||
result := h.statsService.Aggregate(stats, groupBy)
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// aggregate 执行聚合
|
||||
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
|
||||
switch groupBy {
|
||||
case "provider":
|
||||
return h.aggregateByProvider(stats)
|
||||
case "model":
|
||||
return h.aggregateByModel(stats)
|
||||
case "date":
|
||||
return h.aggregateByDate(stats)
|
||||
default:
|
||||
// 默认按供应商聚合
|
||||
return h.aggregateByProvider(stats)
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateByProvider 按供应商聚合
|
||||
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
aggregated[stat.ProviderID] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for providerID, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": providerID,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByModel 按模型聚合
|
||||
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.ProviderID + "/" + stat.ModelName
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for key, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": key[:len(key)/2],
|
||||
"model_name": key[len(key)/2+1:],
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByDate 按日期聚合
|
||||
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.Date.Format("2006-01-02")
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for date, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user