1
0

feat: 初始化 AI Gateway 项目

实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含:
- OpenAI 和 Anthropic 协议代理
- 供应商和模型管理
- 用量统计
- 前端配置界面
This commit is contained in:
2026-04-15 16:53:28 +08:00
commit 915b004924
53 changed files with 5662 additions and 0 deletions

View File

@@ -0,0 +1,243 @@
package handler
import (
"bufio"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/protocol/anthropic"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
)
// AnthropicHandler Anthropic 协议处理器
type AnthropicHandler struct {
client *provider.Client
router *router.Router
}
// NewAnthropicHandler 创建 Anthropic 处理器
func NewAnthropicHandler() *AnthropicHandler {
return &AnthropicHandler{
client: provider.NewClient(),
router: router.NewRouter(),
}
}
// HandleMessages 处理 Messages 请求
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
// 解析 Anthropic 请求
var req anthropic.MessagesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "invalid_request_error",
Message: "无效的请求格式: " + err.Error(),
},
})
return
}
// 检查多模态内容
if err := h.checkMultimodalContent(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "invalid_request_error",
Message: err.Error(),
},
})
return
}
// 转换为 OpenAI 请求
openaiReq, err := anthropic.ConvertRequest(&req)
if err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "invalid_request_error",
Message: "请求转换失败: " + err.Error(),
},
})
return
}
// 路由到供应商
routeResult, err := h.router.Route(openaiReq.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, openaiReq, routeResult)
} else {
h.handleNonStreamRequest(c, openaiReq, routeResult)
}
}
// handleNonStreamRequest 处理非流式请求
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "api_error",
Message: "供应商请求失败: " + err.Error(),
},
})
return
}
// 转换为 Anthropic 响应
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "api_error",
Message: "响应转换失败: " + err.Error(),
},
})
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
}()
// 返回响应
c.JSON(http.StatusOK, anthropicResp)
}
// handleStreamRequest 处理流式请求
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "api_error",
Message: "供应商请求失败: " + err.Error(),
},
})
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 创建流式转换器
converter := anthropic.NewStreamConverter(
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
openaiReq.Model,
)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
fmt.Printf("流错误: %v\n", event.Error)
break
}
if event.Done {
break
}
// 解析 OpenAI 流块
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
if err != nil {
fmt.Printf("解析流块失败: %v\n", err)
continue
}
// 转换为 Anthropic 事件
anthropicEvents, err := converter.ConvertChunk(chunk)
if err != nil {
fmt.Printf("转换事件失败: %v\n", err)
continue
}
// 写入事件
for _, ae := range anthropicEvents {
eventStr, err := anthropic.SerializeEvent(ae)
if err != nil {
fmt.Printf("序列化事件失败: %v\n", err)
continue
}
writer.WriteString(eventStr)
writer.Flush()
}
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
}()
}
// checkMultimodalContent 检查多模态内容
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
for _, msg := range req.Messages {
for _, block := range msg.Content {
if block.Type == "image" {
return fmt.Errorf("MVP 不支持多模态内容(图片)")
}
}
}
return nil
}
// handleError 处理路由错误
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型未找到",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型已禁用",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "供应商已禁用",
},
})
default:
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "internal_error",
Message: "内部错误: " + err.Error(),
},
})
}
}

View File

@@ -0,0 +1,161 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// ModelHandler 模型管理处理器
type ModelHandler struct{}
// NewModelHandler 创建模型处理器
func NewModelHandler() *ModelHandler {
return &ModelHandler{}
}
// CreateModel 创建模型
func (h *ModelHandler) CreateModel(c *gin.Context) {
var req struct {
ID string `json:"id" binding:"required"`
ProviderID string `json:"provider_id" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "缺少必需字段: id, provider_id, model_name",
})
return
}
// 创建模型对象
model := &config.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateModel(model)
if err != nil {
if err.Error() == "供应商不存在" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建模型失败: " + err.Error(),
})
return
}
c.JSON(http.StatusCreated, model)
}
// ListModels 列出模型
func (h *ModelHandler) ListModels(c *gin.Context) {
providerID := c.Query("provider_id")
models, err := config.ListModels(providerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, models)
}
// GetModel 获取模型
func (h *ModelHandler) GetModel(c *gin.Context) {
id := c.Param("id")
model, err := config.GetModel(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, model)
}
// UpdateModel 更新模型
func (h *ModelHandler) UpdateModel(c *gin.Context) {
id := c.Param("id")
var req map[string]interface{}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的请求格式",
})
return
}
// 更新模型
err := config.UpdateModel(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
if err.Error() == "供应商不存在" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新模型失败: " + err.Error(),
})
return
}
// 返回更新后的模型
model, err := config.GetModel(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的模型失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, model)
}
// DeleteModel 删除模型
func (h *ModelHandler) DeleteModel(c *gin.Context) {
id := c.Param("id")
err := config.DeleteModel(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除模型失败: " + err.Error(),
})
return
}
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,167 @@
package handler
import (
"bufio"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
)
// OpenAIHandler OpenAI 协议处理器
type OpenAIHandler struct {
client *provider.Client
router *router.Router
}
// NewOpenAIHandler 创建 OpenAI 处理器
func NewOpenAIHandler() *OpenAIHandler {
return &OpenAIHandler{
client: provider.NewClient(),
router: router.NewRouter(),
}
}
// HandleChatCompletions 处理 Chat Completions 请求
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
// 解析请求
var req openai.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "无效的请求格式: " + err.Error(),
Type: "invalid_request_error",
},
})
return
}
// 路由到供应商
routeResult, err := h.router.Route(req.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, &req, routeResult)
} else {
h.handleNonStreamRequest(c, &req, routeResult)
}
}
// handleNonStreamRequest 处理非流式请求
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "供应商请求失败: " + err.Error(),
Type: "api_error",
},
})
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
}()
// 返回响应
c.JSON(http.StatusOK, resp)
}
// handleStreamRequest 处理流式请求
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "供应商请求失败: " + err.Error(),
Type: "api_error",
},
})
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
// 流错误,记录日志
fmt.Printf("流错误: %v\n", event.Error)
break
}
if event.Done {
// 流结束
writer.WriteString("data: [DONE]\n\n")
writer.Flush()
break
}
// 写入事件数据
writer.WriteString("data: ")
writer.Write(event.Data)
writer.WriteString("\n\n")
writer.Flush()
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
}()
}
// handleError 处理路由错误
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型未找到",
Type: "invalid_request_error",
Code: "model_not_found",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型已禁用",
Type: "invalid_request_error",
Code: "model_disabled",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "供应商已禁用",
Type: "invalid_request_error",
Code: "provider_disabled",
},
})
default:
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "内部错误: " + err.Error(),
Type: "internal_error",
},
})
}
}

View File

@@ -0,0 +1,167 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// ProviderHandler 供应商管理处理器
type ProviderHandler struct{}
// NewProviderHandler 创建供应商处理器
func NewProviderHandler() *ProviderHandler {
return &ProviderHandler{}
}
// CreateProvider 创建供应商
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
var req struct {
ID string `json:"id" binding:"required"`
Name string `json:"name" binding:"required"`
APIKey string `json:"api_key" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "缺少必需字段: id, name, api_key, base_url",
})
return
}
// 创建供应商对象
provider := &config.Provider{
ID: req.ID,
Name: req.Name,
APIKey: req.APIKey,
BaseURL: req.BaseURL,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateProvider(provider)
if err != nil {
// 检查是否是唯一约束错误ID 重复)
if err.Error() == "UNIQUE constraint failed: providers.id" {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建供应商失败: " + err.Error(),
})
return
}
// 掩码 API Key 后返回
provider.MaskAPIKey()
c.JSON(http.StatusCreated, provider)
}
// ListProviders 列出所有供应商
func (h *ProviderHandler) ListProviders(c *gin.Context) {
providers, err := config.ListProviders()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, providers)
}
// GetProvider 获取供应商
func (h *ProviderHandler) GetProvider(c *gin.Context) {
id := c.Param("id")
provider, err := config.GetProvider(id, true) // 掩码 API Key
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, provider)
}
// UpdateProvider 更新供应商
func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
id := c.Param("id")
var req map[string]interface{}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的请求格式",
})
return
}
// 更新供应商
err := config.UpdateProvider(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新供应商失败: " + err.Error(),
})
return
}
// 返回更新后的供应商
provider, err := config.GetProvider(id, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的供应商失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, provider)
}
// DeleteProvider 删除供应商
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
id := c.Param("id")
// 删除供应商(级联删除模型)
err := config.DeleteProvider(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除供应商失败: " + err.Error(),
})
return
}
// 删除关联的模型
models, _ := config.ListModels("")
for _, model := range models {
if model.ProviderID == id {
_ = config.DeleteModel(model.ID)
}
}
c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,184 @@
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
)
// StatsHandler 统计处理器
type StatsHandler struct{}
// NewStatsHandler 创建统计处理器
func NewStatsHandler() *StatsHandler {
return &StatsHandler{}
}
// GetStats 查询统计
func (h *StatsHandler) GetStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
})
return
}
startDate = &t
}
if endDateStr != "" {
t, err := time.Parse("2006-01-02", endDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
})
return
}
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, stats)
}
// AggregateStats 聚合统计
func (h *StatsHandler) AggregateStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
groupBy := c.Query("group_by") // "provider", "model", "date"
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
})
return
}
startDate = &t
}
if endDateStr != "" {
t, err := time.Parse("2006-01-02", endDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
})
return
}
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
})
return
}
// 聚合
result := h.aggregate(stats, groupBy)
c.JSON(http.StatusOK, result)
}
// aggregate 执行聚合
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
switch groupBy {
case "provider":
return h.aggregateByProvider(stats)
case "model":
return h.aggregateByModel(stats)
case "date":
return h.aggregateByDate(stats)
default:
// 默认按供应商聚合
return h.aggregateByProvider(stats)
}
}
// aggregateByProvider 按供应商聚合
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
aggregated[stat.ProviderID] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for providerID, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": providerID,
"request_count": count,
})
}
return result
}
// aggregateByModel 按模型聚合
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.ProviderID + "/" + stat.ModelName
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for key, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": key[:len(key)/2],
"model_name": key[len(key)/2+1:],
"request_count": count,
})
}
return result
}
// aggregateByDate 按日期聚合
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.Date.Format("2006-01-02")
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for date, count := range aggregated {
result = append(result, map[string]interface{}{
"date": date,
"request_count": count,
})
}
return result
}