feat: 初始化 AI Gateway 项目
实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含: - OpenAI 和 Anthropic 协议代理 - 供应商和模型管理 - 用量统计 - 前端配置界面
This commit is contained in:
243
backend/internal/handler/anthropic_handler.go
Normal file
243
backend/internal/handler/anthropic_handler.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler() *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
// 解析 Anthropic 请求
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查多模态内容
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 OpenAI 请求
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "请求转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, openaiReq, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 响应
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "响应转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 创建流式转换器
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 OpenAI 流块
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
fmt.Printf("解析流块失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 事件
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
fmt.Printf("转换事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入事件
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
fmt.Printf("序列化事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// checkMultimodalContent 检查多模态内容
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "image" {
|
||||
return fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型未找到",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型已禁用",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "供应商已禁用",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
161
backend/internal/handler/model_handler.go
Normal file
161
backend/internal/handler/model_handler.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
type ModelHandler struct{}
|
||||
|
||||
// NewModelHandler 创建模型处理器
|
||||
func NewModelHandler() *ModelHandler {
|
||||
return &ModelHandler{}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
ProviderID string `json:"provider_id" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, provider_id, model_name",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建模型对象
|
||||
model := &config.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateModel(model)
|
||||
if err != nil {
|
||||
if err.Error() == "供应商不存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model)
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
providerID := c.Query("provider_id")
|
||||
|
||||
models, err := config.ListModels(providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
model, err := config.GetModel(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的请求格式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新模型
|
||||
err := config.UpdateModel(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err.Error() == "供应商不存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的模型
|
||||
model, err := config.GetModel(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := config.DeleteModel(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
167
backend/internal/handler/openai_handler.go
Normal file
167
backend/internal/handler/openai_handler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler() *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
// 解析请求
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, &req, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
// 流错误,记录日志
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
// 流结束
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 写入事件数据
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型未找到",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_not_found",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_disabled",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "provider_disabled",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
167
backend/internal/handler/provider_handler.go
Normal file
167
backend/internal/handler/provider_handler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
type ProviderHandler struct{}
|
||||
|
||||
// NewProviderHandler 创建供应商处理器
|
||||
func NewProviderHandler() *ProviderHandler {
|
||||
return &ProviderHandler{}
|
||||
}
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, name, api_key, base_url",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建供应商对象
|
||||
provider := &config.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateProvider(provider)
|
||||
if err != nil {
|
||||
// 检查是否是唯一约束错误(ID 重复)
|
||||
if err.Error() == "UNIQUE constraint failed: providers.id" {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 掩码 API Key 后返回
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
providers, err := config.ListProviders()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, providers)
|
||||
}
|
||||
|
||||
// GetProvider 获取供应商
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := config.GetProvider(id, true) // 掩码 API Key
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// UpdateProvider 更新供应商
|
||||
func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的请求格式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新供应商
|
||||
err := config.UpdateProvider(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的供应商
|
||||
provider, err := config.GetProvider(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除供应商
|
||||
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 删除供应商(级联删除模型)
|
||||
err := config.DeleteProvider(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除关联的模型
|
||||
models, _ := config.ListModels("")
|
||||
for _, model := range models {
|
||||
if model.ProviderID == id {
|
||||
_ = config.DeleteModel(model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
184
backend/internal/handler/stats_handler.go
Normal file
184
backend/internal/handler/stats_handler.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
type StatsHandler struct{}
|
||||
|
||||
// NewStatsHandler 创建统计处理器
|
||||
func NewStatsHandler() *StatsHandler {
|
||||
return &StatsHandler{}
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
startDate = &t
|
||||
}
|
||||
|
||||
if endDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// AggregateStats 聚合统计
|
||||
func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
groupBy := c.Query("group_by") // "provider", "model", "date"
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
startDate = &t
|
||||
}
|
||||
|
||||
if endDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合
|
||||
result := h.aggregate(stats, groupBy)
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// aggregate 执行聚合
|
||||
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
|
||||
switch groupBy {
|
||||
case "provider":
|
||||
return h.aggregateByProvider(stats)
|
||||
case "model":
|
||||
return h.aggregateByModel(stats)
|
||||
case "date":
|
||||
return h.aggregateByDate(stats)
|
||||
default:
|
||||
// 默认按供应商聚合
|
||||
return h.aggregateByProvider(stats)
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateByProvider 按供应商聚合
|
||||
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
aggregated[stat.ProviderID] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for providerID, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": providerID,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByModel 按模型聚合
|
||||
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.ProviderID + "/" + stat.ModelName
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for key, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": key[:len(key)/2],
|
||||
"model_name": key[len(key)/2+1:],
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByDate 按日期聚合
|
||||
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.Date.Format("2006-01-02")
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for date, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user