1
0

feat: 初始化 AI Gateway 项目

实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含:
- OpenAI 和 Anthropic 协议代理
- 供应商和模型管理
- 用量统计
- 前端配置界面
This commit is contained in:
2026-04-15 16:53:28 +08:00
commit 915b004924
53 changed files with 5662 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
package config
import (
"os"
"path/filepath"
)
// GetConfigDir returns the configuration directory path (~/.nex/),
// creating the directory if it does not yet exist.
func GetConfigDir() (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	dir := filepath.Join(home, ".nex")
	// Make sure the directory exists before handing it out.
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", err
	}
	return dir, nil
}

// GetDBPath returns the path of the SQLite database file inside the
// configuration directory.
func GetDBPath() (string, error) {
	dir, err := GetConfigDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(dir, "config.db"), nil
}

View File

@@ -0,0 +1,58 @@
package config
import (
"fmt"
"log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// db is the package-level GORM handle shared by all config accessors.
var db *gorm.DB

// InitDB opens the SQLite database under the config directory, enables
// WAL journaling, and auto-migrates the schema.
func InitDB() error {
	path, err := GetDBPath()
	if err != nil {
		return fmt.Errorf("获取数据库路径失败: %w", err)
	}
	// Open the connection with verbose SQL logging.
	db, err = gorm.Open(sqlite.Open(path), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Info),
	})
	if err != nil {
		return fmt.Errorf("连接数据库失败: %w", err)
	}
	// WAL mode improves concurrent read/write behavior; a failure here
	// is non-fatal, so only warn.
	if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
		log.Printf("警告: 启用 WAL 模式失败: %v", err)
	}
	// Create/upgrade all tables.
	if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
		return fmt.Errorf("创建表失败: %w", err)
	}
	log.Printf("数据库初始化成功: %s", path)
	return nil
}
// GetDB returns the shared *gorm.DB handle, or nil if InitDB has not
// completed successfully yet (callers must check for nil).
func GetDB() *gorm.DB {
	return db
}
// CloseDB closes the underlying SQL connection if one is open; it is a
// no-op when the database was never initialized.
func CloseDB() error {
	if db == nil {
		return nil
	}
	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

View File

@@ -0,0 +1,119 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateModel inserts a new model after verifying that the referenced
// provider exists.
func CreateModel(model *Model) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	// Reject models pointing at an unknown provider.
	var p Provider
	if err := conn.First(&p, "id = ?", model.ProviderID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("供应商不存在")
		}
		return err
	}
	model.CreatedAt = time.Now()
	return conn.Create(model).Error
}
// GetModel looks up a single model by its primary key.
func GetModel(id string) (*Model, error) {
	conn := GetDB()
	if conn == nil {
		return nil, errors.New("数据库未初始化")
	}
	var m Model
	if err := conn.First(&m, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return &m, nil
}
// ListModels returns all models, optionally filtered by provider ID
// (empty string means no filter).
func ListModels(providerID string) ([]Model, error) {
	conn := GetDB()
	if conn == nil {
		return nil, errors.New("数据库未初始化")
	}
	tx := conn
	if providerID != "" {
		tx = tx.Where("provider_id = ?", providerID)
	}
	var models []Model
	if err := tx.Find(&models).Error; err != nil {
		return nil, err
	}
	return models, nil
}
// UpdateModel applies a partial update to a model. When provider_id is
// part of the update, the new provider must exist. Returns
// gorm.ErrRecordNotFound when no row matched the ID.
func UpdateModel(id string, updates map[string]interface{}) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	// Validate a provider re-assignment before touching the row.
	if pid, ok := updates["provider_id"].(string); ok {
		var p Provider
		if err := conn.First(&p, "id = ?", pid).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return errors.New("供应商不存在")
			}
			return err
		}
	}
	res := conn.Model(&Model{}).Where("id = ?", id).Updates(updates)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	}
	return nil
}
// DeleteModel removes a model by ID; gorm.ErrRecordNotFound is
// returned when nothing was deleted.
func DeleteModel(id string) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	res := conn.Delete(&Model{}, "id = ?", id)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	}
	return nil
}

View File

@@ -0,0 +1,57 @@
package config
import (
"time"
)
// Provider is an upstream LLM API vendor configuration.
type Provider struct {
	ID      string `gorm:"primaryKey" json:"id"`
	Name    string `gorm:"not null" json:"name"`
	APIKey  string `gorm:"not null" json:"api_key"`
	BaseURL string `gorm:"not null" json:"base_url"`
	Enabled bool   `gorm:"default:true" json:"enabled"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
	// Models are deleted together with their provider (cascade FK).
	Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}

// Model maps a routable model name onto the provider that serves it.
type Model struct {
	ID         string    `gorm:"primaryKey" json:"id"`
	ProviderID string    `gorm:"not null;index" json:"provider_id"`
	ModelName  string    `gorm:"not null;index" json:"model_name"`
	Enabled    bool      `gorm:"default:true" json:"enabled"`
	CreatedAt  time.Time `json:"created_at"`
}

// UsageStats is a per-provider/model/day request counter.
//
// BUG FIX: the unique index idx_provider_model_date must span
// provider_id, model_name AND date. Previously the tag was only on the
// Date column, making date alone unique — i.e. at most one stats row
// per day across ALL providers and models, so RecordRequest's
// find-or-create for a second provider/model on the same day would
// violate the index.
type UsageStats struct {
	ID           uint      `gorm:"primaryKey;autoIncrement" json:"id"`
	ProviderID   string    `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
	ModelName    string    `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
	RequestCount int       `gorm:"default:0" json:"request_count"`
	Date         time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
}

// TableName fixes the providers table name.
func (Provider) TableName() string {
	return "providers"
}

// TableName fixes the models table name.
func (Model) TableName() string {
	return "models"
}

// TableName fixes the usage_stats table name.
func (UsageStats) TableName() string {
	return "usage_stats"
}

// MaskAPIKey redacts the API key in place, keeping only its last 4
// characters; keys of 4 characters or fewer are fully redacted.
func (p *Provider) MaskAPIKey() {
	if len(p.APIKey) > 4 {
		p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
	} else {
		p.APIKey = "***"
	}
}

View File

@@ -0,0 +1,102 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateProvider inserts a new provider, stamping both creation and
// update times with the same instant.
func CreateProvider(provider *Provider) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	now := time.Now()
	provider.CreatedAt = now
	provider.UpdatedAt = now
	return conn.Create(provider).Error
}
// GetProvider fetches one provider by ID; when maskKey is true the
// returned API key is redacted.
func GetProvider(id string, maskKey bool) (*Provider, error) {
	conn := GetDB()
	if conn == nil {
		return nil, errors.New("数据库未初始化")
	}
	var p Provider
	if err := conn.First(&p, "id = ?", id).Error; err != nil {
		return nil, err
	}
	if maskKey {
		p.MaskAPIKey()
	}
	return &p, nil
}
// ListProviders returns every provider; API keys are always masked.
func ListProviders() ([]Provider, error) {
	conn := GetDB()
	if conn == nil {
		return nil, errors.New("数据库未初始化")
	}
	var providers []Provider
	if err := conn.Find(&providers).Error; err != nil {
		return nil, err
	}
	// Never leak raw keys from a list endpoint.
	for i := range providers {
		providers[i].MaskAPIKey()
	}
	return providers, nil
}
// UpdateProvider applies a partial update and bumps updated_at.
// Note: the caller's updates map is mutated (updated_at is injected).
// Returns gorm.ErrRecordNotFound when no row matched the ID.
func UpdateProvider(id string, updates map[string]interface{}) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	updates["updated_at"] = time.Now()
	res := conn.Model(&Provider{}).Where("id = ?", id).Updates(updates)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	}
	return nil
}
// DeleteProvider removes a provider by ID; gorm.ErrRecordNotFound is
// returned when nothing was deleted.
func DeleteProvider(id string) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	res := conn.Delete(&Provider{}, "id = ?", id)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	}
	return nil
}

View File

@@ -0,0 +1,79 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// RecordRequest increments today's request counter for the given
// provider/model pair, creating the row on first use. The
// find-or-create plus increment runs inside one transaction.
func RecordRequest(providerID, modelName string) error {
	conn := GetDB()
	if conn == nil {
		return errors.New("数据库未初始化")
	}
	// Normalize "now" to a date-only value; parsing the output of
	// Format with the same layout cannot fail, so the error is ignored.
	day, _ := time.Parse("2006-01-02", time.Now().Format("2006-01-02"))
	return conn.Transaction(func(tx *gorm.DB) error {
		var row UsageStats
		err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
			providerID, modelName, day).
			First(&row).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			// First request of the day for this pair.
			row = UsageStats{
				ProviderID:   providerID,
				ModelName:    modelName,
				RequestCount: 1,
				Date:         day,
			}
			return tx.Create(&row).Error
		case err != nil:
			return err
		}
		// Atomic in-database increment.
		return tx.Model(&row).Update("request_count", gorm.Expr("request_count + 1")).Error
	})
}
// GetStats returns raw usage rows matching the optional filters
// (empty string / nil means "no filter"), newest date first.
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
	conn := GetDB()
	if conn == nil {
		return nil, errors.New("数据库未初始化")
	}
	q := conn.Model(&UsageStats{})
	// Apply only the filters the caller supplied.
	if providerID != "" {
		q = q.Where("provider_id = ?", providerID)
	}
	if modelName != "" {
		q = q.Where("model_name = ?", modelName)
	}
	if startDate != nil {
		q = q.Where("date >= ?", startDate)
	}
	if endDate != nil {
		q = q.Where("date <= ?", endDate)
	}
	var rows []UsageStats
	if err := q.Order("date DESC").Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}

View File

@@ -0,0 +1,243 @@
package handler
import (
"bufio"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/protocol/anthropic"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
)
// AnthropicHandler serves the Anthropic Messages protocol by
// translating requests to/from OpenAI-compatible upstreams.
type AnthropicHandler struct {
	client *provider.Client
	router *router.Router
}

// NewAnthropicHandler constructs a handler with a fresh upstream
// client and model router.
func NewAnthropicHandler() *AnthropicHandler {
	h := &AnthropicHandler{}
	h.client = provider.NewClient()
	h.router = router.NewRouter()
	return h
}
// HandleMessages handles POST /v1/messages: parse, validate, convert
// to the OpenAI shape, route, then dispatch stream vs. non-stream.
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
	// Shared 400 responder in the Anthropic error envelope.
	badRequest := func(msg string) {
		c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
			Type: "error",
			Error: anthropic.ErrorDetail{
				Type:    "invalid_request_error",
				Message: msg,
			},
		})
	}
	var req anthropic.MessagesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		badRequest("无效的请求格式: " + err.Error())
		return
	}
	// The MVP rejects image blocks up front.
	if err := h.checkMultimodalContent(&req); err != nil {
		badRequest(err.Error())
		return
	}
	openaiReq, err := anthropic.ConvertRequest(&req)
	if err != nil {
		badRequest("请求转换失败: " + err.Error())
		return
	}
	routeResult, err := h.router.Route(openaiReq.Model)
	if err != nil {
		h.handleError(c, err)
		return
	}
	if req.Stream {
		h.handleStreamRequest(c, openaiReq, routeResult)
		return
	}
	h.handleNonStreamRequest(c, openaiReq, routeResult)
}
// handleNonStreamRequest proxies a blocking request upstream and
// converts the OpenAI response back into the Anthropic shape.
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	// Shared 500 responder in the Anthropic error envelope.
	apiError := func(msg string) {
		c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
			Type: "error",
			Error: anthropic.ErrorDetail{
				Type:    "api_error",
				Message: msg,
			},
		})
	}
	openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		apiError("供应商请求失败: " + err.Error())
		return
	}
	anthropicResp, err := anthropic.ConvertResponse(openaiResp)
	if err != nil {
		apiError("响应转换失败: " + err.Error())
		return
	}
	// Usage accounting happens off the request path; errors are
	// deliberately ignored.
	go func() {
		_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
	}()
	c.JSON(http.StatusOK, anthropicResp)
}
// handleStreamRequest proxies a streaming request, translating the
// upstream OpenAI SSE stream into Anthropic stream events on the fly.
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
			Type: "error",
			Error: anthropic.ErrorDetail{
				Type:    "api_error",
				Message: "供应商请求失败: " + err.Error(),
			},
		})
		return
	}
	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	writer := bufio.NewWriter(c.Writer)
	// Stateful converter; the message ID here is synthesized from the
	// provider ID rather than being globally unique.
	converter := anthropic.NewStreamConverter(
		fmt.Sprintf("msg_%s", routeResult.Provider.ID),
		openaiReq.Model,
	)
	// Forward events until error or end-of-stream. Individual bad
	// chunks are skipped so one malformed event does not kill the
	// stream.
	for event := range eventChan {
		if event.Error != nil {
			fmt.Printf("流错误: %v\n", event.Error)
			break
		}
		if event.Done {
			break
		}
		chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
		if err != nil {
			fmt.Printf("解析流块失败: %v\n", err)
			continue
		}
		anthropicEvents, err := converter.ConvertChunk(chunk)
		if err != nil {
			fmt.Printf("转换事件失败: %v\n", err)
			continue
		}
		for _, ae := range anthropicEvents {
			eventStr, err := anthropic.SerializeEvent(ae)
			if err != nil {
				fmt.Printf("序列化事件失败: %v\n", err)
				continue
			}
			writer.WriteString(eventStr)
			writer.Flush()
			// BUG FIX: bufio.Flush only hands bytes to net/http's own
			// response buffer; without flushing the gin response writer
			// (http.Flusher) the "stream" would reach the client in one
			// burst when the handler returns. Flush after every event.
			c.Writer.Flush()
		}
	}
	// Usage accounting happens off the request path; errors are
	// deliberately ignored.
	go func() {
		_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
	}()
}
// checkMultimodalContent rejects requests containing image blocks,
// which the MVP does not support.
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
	for _, message := range req.Messages {
		for _, part := range message.Content {
			if part.Type != "image" {
				continue
			}
			return fmt.Errorf("MVP 不支持多模态内容(图片)")
		}
	}
	return nil
}
// handleError maps router errors onto Anthropic error responses:
// known routing failures become 404s, anything else a 500.
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
	status := http.StatusNotFound
	errType := "not_found_error"
	var msg string
	switch err {
	case router.ErrModelNotFound:
		msg = "模型未找到"
	case router.ErrModelDisabled:
		msg = "模型已禁用"
	case router.ErrProviderDisabled:
		msg = "供应商已禁用"
	default:
		status = http.StatusInternalServerError
		errType = "internal_error"
		msg = "内部错误: " + err.Error()
	}
	c.JSON(status, anthropic.ErrorResponse{
		Type: "error",
		Error: anthropic.ErrorDetail{
			Type:    errType,
			Message: msg,
		},
	})
}

View File

@@ -0,0 +1,161 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// ModelHandler exposes CRUD endpoints for model configuration. It is
// stateless; all persistence goes through the config package.
type ModelHandler struct{}

// NewModelHandler returns a new stateless model handler.
func NewModelHandler() *ModelHandler {
	return new(ModelHandler)
}
// CreateModel handles POST /models; new models start enabled.
func (h *ModelHandler) CreateModel(c *gin.Context) {
	var req struct {
		ID         string `json:"id" binding:"required"`
		ProviderID string `json:"provider_id" binding:"required"`
		ModelName  string `json:"model_name" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少必需字段: id, provider_id, model_name"})
		return
	}
	m := &config.Model{
		ID:         req.ID,
		ProviderID: req.ProviderID,
		ModelName:  req.ModelName,
		Enabled:    true, // enabled by default
	}
	if err := config.CreateModel(m); err != nil {
		// config.CreateModel signals a missing provider via this
		// sentinel message.
		if err.Error() == "供应商不存在" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "供应商不存在"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "创建模型失败: " + err.Error()})
		}
		return
	}
	c.JSON(http.StatusCreated, m)
}
// ListModels handles GET /models with an optional provider_id filter.
func (h *ModelHandler) ListModels(c *gin.Context) {
	models, err := config.ListModels(c.Query("provider_id"))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败: " + err.Error()})
		return
	}
	c.JSON(http.StatusOK, models)
}
// GetModel handles GET /models/:id.
func (h *ModelHandler) GetModel(c *gin.Context) {
	m, err := config.GetModel(c.Param("id"))
	switch {
	case err == gorm.ErrRecordNotFound:
		c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败: " + err.Error()})
	default:
		c.JSON(http.StatusOK, m)
	}
}
// UpdateModel handles partial updates to a model and echoes the
// refreshed record back to the caller.
func (h *ModelHandler) UpdateModel(c *gin.Context) {
	id := c.Param("id")
	var updates map[string]interface{}
	if err := c.ShouldBindJSON(&updates); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求格式"})
		return
	}
	if err := config.UpdateModel(id, updates); err != nil {
		switch {
		case err == gorm.ErrRecordNotFound:
			c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
		case err.Error() == "供应商不存在":
			c.JSON(http.StatusBadRequest, gin.H{"error": "供应商不存在"})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "更新模型失败: " + err.Error()})
		}
		return
	}
	// Re-read so the response reflects the persisted state.
	m, err := config.GetModel(id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询更新后的模型失败: " + err.Error()})
		return
	}
	c.JSON(http.StatusOK, m)
}
// DeleteModel handles DELETE /models/:id; 204 on success.
func (h *ModelHandler) DeleteModel(c *gin.Context) {
	switch err := config.DeleteModel(c.Param("id")); {
	case err == gorm.ErrRecordNotFound:
		c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "删除模型失败: " + err.Error()})
	default:
		c.Status(http.StatusNoContent)
	}
}

View File

@@ -0,0 +1,167 @@
package handler
import (
"bufio"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
)
// OpenAIHandler serves the OpenAI Chat Completions protocol as a
// pass-through proxy to the routed upstream provider.
type OpenAIHandler struct {
	client *provider.Client
	router *router.Router
}

// NewOpenAIHandler constructs a handler with a fresh upstream client
// and model router.
func NewOpenAIHandler() *OpenAIHandler {
	h := &OpenAIHandler{}
	h.client = provider.NewClient()
	h.router = router.NewRouter()
	return h
}
// HandleChatCompletions handles POST /v1/chat/completions: parse,
// route by model, then dispatch stream vs. non-stream.
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
	var req openai.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "无效的请求格式: " + err.Error(),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	routeResult, err := h.router.Route(req.Model)
	if err != nil {
		h.handleError(c, err)
		return
	}
	if req.Stream {
		h.handleStreamRequest(c, &req, routeResult)
		return
	}
	h.handleNonStreamRequest(c, &req, routeResult)
}
// handleNonStreamRequest proxies a blocking request and relays the
// upstream response body unchanged.
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}
	// Usage accounting happens off the request path; errors are
	// deliberately ignored.
	go func() {
		_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
	}()
	c.JSON(http.StatusOK, resp)
}
// handleStreamRequest proxies a streaming request, relaying upstream
// SSE chunks to the client as they arrive.
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
	eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
	if err != nil {
		c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
			Error: openai.ErrorDetail{
				Message: "供应商请求失败: " + err.Error(),
				Type:    "api_error",
			},
		})
		return
	}
	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	writer := bufio.NewWriter(c.Writer)
	// flush pushes buffered bytes through bufio AND the HTTP response
	// writer. BUG FIX: bufio.Flush alone only hands bytes to net/http's
	// own buffer; without flushing the gin response writer
	// (http.Flusher) the client would receive the whole "stream" in one
	// burst when the handler returns.
	flush := func() {
		writer.Flush()
		c.Writer.Flush()
	}
	for event := range eventChan {
		if event.Error != nil {
			// Upstream stream error; stop relaying.
			fmt.Printf("流错误: %v\n", event.Error)
			break
		}
		if event.Done {
			// Terminate the SSE stream the way OpenAI clients expect.
			writer.WriteString("data: [DONE]\n\n")
			flush()
			break
		}
		writer.WriteString("data: ")
		writer.Write(event.Data)
		writer.WriteString("\n\n")
		flush()
	}
	// Usage accounting happens off the request path; errors are
	// deliberately ignored.
	go func() {
		_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
	}()
}
// handleError maps router errors onto OpenAI error responses: known
// routing failures become 404s with a machine-readable code, anything
// else a 500.
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
	status := http.StatusNotFound
	detail := openai.ErrorDetail{Type: "invalid_request_error"}
	switch err {
	case router.ErrModelNotFound:
		detail.Message = "模型未找到"
		detail.Code = "model_not_found"
	case router.ErrModelDisabled:
		detail.Message = "模型已禁用"
		detail.Code = "model_disabled"
	case router.ErrProviderDisabled:
		detail.Message = "供应商已禁用"
		detail.Code = "provider_disabled"
	default:
		status = http.StatusInternalServerError
		detail.Type = "internal_error"
		detail.Message = "内部错误: " + err.Error()
	}
	c.JSON(status, openai.ErrorResponse{Error: detail})
}

View File

@@ -0,0 +1,167 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// ProviderHandler exposes CRUD endpoints for provider configuration.
// It is stateless; all persistence goes through the config package.
type ProviderHandler struct{}

// NewProviderHandler returns a new stateless provider handler.
func NewProviderHandler() *ProviderHandler {
	return new(ProviderHandler)
}
// CreateProvider handles POST /providers; new providers start enabled
// and the API key is masked in the response.
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
	var req struct {
		ID      string `json:"id" binding:"required"`
		Name    string `json:"name" binding:"required"`
		APIKey  string `json:"api_key" binding:"required"`
		BaseURL string `json:"base_url" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少必需字段: id, name, api_key, base_url"})
		return
	}
	p := &config.Provider{
		ID:      req.ID,
		Name:    req.Name,
		APIKey:  req.APIKey,
		BaseURL: req.BaseURL,
		Enabled: true, // enabled by default
	}
	if err := config.CreateProvider(p); err != nil {
		// A duplicate primary key surfaces as a SQLite UNIQUE
		// violation; detected by message text here.
		if err.Error() == "UNIQUE constraint failed: providers.id" {
			c.JSON(http.StatusConflict, gin.H{"error": "供应商 ID 已存在"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "创建供应商失败: " + err.Error()})
		}
		return
	}
	// Never return the raw key.
	p.MaskAPIKey()
	c.JSON(http.StatusCreated, p)
}
// ListProviders handles GET /providers; API keys come back masked.
func (h *ProviderHandler) ListProviders(c *gin.Context) {
	providers, err := config.ListProviders()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询供应商失败: " + err.Error()})
		return
	}
	c.JSON(http.StatusOK, providers)
}
// GetProvider handles GET /providers/:id; the API key is masked.
func (h *ProviderHandler) GetProvider(c *gin.Context) {
	p, err := config.GetProvider(c.Param("id"), true)
	switch {
	case err == gorm.ErrRecordNotFound:
		c.JSON(http.StatusNotFound, gin.H{"error": "供应商未找到"})
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询供应商失败: " + err.Error()})
	default:
		c.JSON(http.StatusOK, p)
	}
}
// UpdateProvider handles partial updates to a provider and echoes the
// refreshed record (API key masked) back to the caller.
func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
	id := c.Param("id")
	var updates map[string]interface{}
	if err := c.ShouldBindJSON(&updates); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求格式"})
		return
	}
	if err := config.UpdateProvider(id, updates); err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "供应商未找到"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "更新供应商失败: " + err.Error()})
		}
		return
	}
	// Re-read so the response reflects the persisted state.
	p, err := config.GetProvider(id, true)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询更新后的供应商失败: " + err.Error()})
		return
	}
	c.JSON(http.StatusOK, p)
}
// DeleteProvider handles DELETE /providers/:id and also removes the
// provider's models.
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
	id := c.Param("id")
	if err := config.DeleteProvider(id); err != nil {
		if err == gorm.ErrRecordNotFound {
			c.JSON(http.StatusNotFound, gin.H{"error": "供应商未找到"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "删除供应商失败: " + err.Error()})
		}
		return
	}
	// Best-effort cleanup of the provider's models; individual delete
	// failures are ignored.
	models, _ := config.ListModels("")
	for _, m := range models {
		if m.ProviderID == id {
			_ = config.DeleteModel(m.ID)
		}
	}
	c.Status(http.StatusNoContent)
}

View File

@@ -0,0 +1,184 @@
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
)
// StatsHandler exposes usage-statistics query endpoints. It is
// stateless; all persistence goes through the config package.
type StatsHandler struct{}

// NewStatsHandler returns a new stateless stats handler.
func NewStatsHandler() *StatsHandler {
	return new(StatsHandler)
}
// GetStats handles GET /stats with optional provider_id, model_name,
// start_date and end_date (YYYY-MM-DD) filters.
func (h *StatsHandler) GetStats(c *gin.Context) {
	// parseDay turns an optional YYYY-MM-DD query value into a
	// *time.Time; ok is false on a malformed value.
	parseDay := func(value string) (*time.Time, bool) {
		if value == "" {
			return nil, true
		}
		t, err := time.Parse("2006-01-02", value)
		if err != nil {
			return nil, false
		}
		return &t, true
	}
	startDate, ok := parseDay(c.Query("start_date"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 start_date 格式,应为 YYYY-MM-DD"})
		return
	}
	endDate, ok := parseDay(c.Query("end_date"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 end_date 格式,应为 YYYY-MM-DD"})
		return
	}
	stats, err := config.GetStats(c.Query("provider_id"), c.Query("model_name"), startDate, endDate)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询统计失败: " + err.Error()})
		return
	}
	c.JSON(http.StatusOK, stats)
}
// AggregateStats handles GET on the aggregate endpoint, grouping
// totals by provider (default), model, or date via group_by.
func (h *StatsHandler) AggregateStats(c *gin.Context) {
	// parseDay turns an optional YYYY-MM-DD query value into a
	// *time.Time; ok is false on a malformed value.
	parseDay := func(value string) (*time.Time, bool) {
		if value == "" {
			return nil, true
		}
		t, err := time.Parse("2006-01-02", value)
		if err != nil {
			return nil, false
		}
		return &t, true
	}
	startDate, ok := parseDay(c.Query("start_date"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 start_date 格式,应为 YYYY-MM-DD"})
		return
	}
	endDate, ok := parseDay(c.Query("end_date"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 end_date 格式,应为 YYYY-MM-DD"})
		return
	}
	stats, err := config.GetStats(c.Query("provider_id"), c.Query("model_name"), startDate, endDate)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询统计失败: " + err.Error()})
		return
	}
	// group_by: "provider" (default), "model" or "date".
	c.JSON(http.StatusOK, h.aggregate(stats, c.Query("group_by")))
}
// aggregate dispatches to the grouping strategy for groupBy; any
// unrecognized value falls back to grouping by provider.
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
	switch groupBy {
	case "model":
		return h.aggregateByModel(stats)
	case "date":
		return h.aggregateByDate(stats)
	case "provider":
		return h.aggregateByProvider(stats)
	default:
		return h.aggregateByProvider(stats)
	}
}
// aggregateByProvider sums request counts per provider ID. Result
// order is unspecified (map iteration).
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
	totals := make(map[string]int)
	for _, s := range stats {
		totals[s.ProviderID] += s.RequestCount
	}
	out := make([]map[string]interface{}, 0, len(totals))
	for id, n := range totals {
		out = append(out, map[string]interface{}{
			"provider_id":   id,
			"request_count": n,
		})
	}
	return out
}
// aggregateByModel sums request counts per (provider, model) pair.
// Result order is unspecified (map iteration).
//
// BUG FIX: the old implementation joined provider and model into a
// single "provider/model" string and then split it at len(key)/2
// instead of at the separator, so provider_id/model_name in the output
// were garbage whenever the two parts differed in length (and broke
// outright for names containing "/"). A composite struct key removes
// string splitting entirely.
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
	type pair struct {
		provider string
		model    string
	}
	totals := make(map[pair]int)
	for _, s := range stats {
		totals[pair{s.ProviderID, s.ModelName}] += s.RequestCount
	}
	out := make([]map[string]interface{}, 0, len(totals))
	for key, n := range totals {
		out = append(out, map[string]interface{}{
			"provider_id":   key.provider,
			"model_name":    key.model,
			"request_count": n,
		})
	}
	return out
}
// aggregateByDate sums request counts per calendar day (YYYY-MM-DD).
// Result order is unspecified (map iteration).
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
	totals := make(map[string]int)
	for _, s := range stats {
		totals[s.Date.Format("2006-01-02")] += s.RequestCount
	}
	out := make([]map[string]interface{}, 0, len(totals))
	for day, n := range totals {
		out = append(out, map[string]interface{}{
			"date":          day,
			"request_count": n,
		})
	}
	return out
}

View File

@@ -0,0 +1,234 @@
package anthropic
import (
"encoding/json"
"fmt"
"nex/backend/internal/protocol/openai"
)
// ConvertRequest translates an Anthropic Messages request into the
// OpenAI Chat Completions shape.
func ConvertRequest(anthropicReq *MessagesRequest) (*openai.ChatCompletionRequest, error) {
	out := &openai.ChatCompletionRequest{
		Model:       anthropicReq.Model,
		Temperature: anthropicReq.Temperature,
		TopP:        anthropicReq.TopP,
		Stream:      anthropicReq.Stream,
	}
	// max_tokens is mandatory on the Anthropic side; fall back to 4096
	// when the caller sent nothing usable.
	maxTokens := 4096
	if anthropicReq.MaxTokens > 0 {
		maxTokens = anthropicReq.MaxTokens
	}
	out.MaxTokens = &maxTokens
	if len(anthropicReq.StopSequences) > 0 {
		out.Stop = anthropicReq.StopSequences
	}
	// The system prompt becomes a leading system-role message.
	msgs := make([]openai.Message, 0, len(anthropicReq.Messages)+1)
	if anthropicReq.System != "" {
		msgs = append(msgs, openai.Message{
			Role:    "system",
			Content: anthropicReq.System,
		})
	}
	for _, m := range anthropicReq.Messages {
		converted, err := convertMessage(m)
		if err != nil {
			return nil, err
		}
		msgs = append(msgs, converted...)
	}
	out.Messages = msgs
	// Anthropic tools map 1:1 onto OpenAI function tools.
	if n := len(anthropicReq.Tools); n > 0 {
		out.Tools = make([]openai.Tool, n)
		for i, tool := range anthropicReq.Tools {
			out.Tools[i] = openai.Tool{
				Type: "function",
				Function: openai.FunctionDefinition{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  tool.InputSchema,
				},
			}
		}
	}
	if anthropicReq.ToolChoice != nil {
		tc, err := convertToolChoice(anthropicReq.ToolChoice)
		if err != nil {
			return nil, err
		}
		out.ToolChoice = tc
	}
	return out, nil
}
// ConvertResponse translates an OpenAI chat completion into an
// Anthropic Messages response. Only the first choice is used.
func ConvertResponse(openaiResp *openai.ChatCompletionResponse) (*MessagesResponse, error) {
	resp := &MessagesResponse{
		ID:    openaiResp.ID,
		Type:  "message",
		Role:  "assistant",
		Model: openaiResp.Model,
		Usage: Usage{
			InputTokens:  openaiResp.Usage.PromptTokens,
			OutputTokens: openaiResp.Usage.CompletionTokens,
		},
	}
	if len(openaiResp.Choices) == 0 {
		return resp, nil
	}
	choice := openaiResp.Choices[0]
	blocks := make([]ContentBlock, 0)
	if choice.Message != nil {
		// Non-empty string content becomes a text block.
		if text, ok := choice.Message.Content.(string); ok && text != "" {
			blocks = append(blocks, ContentBlock{
				Type: "text",
				Text: text,
			})
		}
		// Each tool invocation becomes a tool_use block with its
		// arguments decoded from JSON.
		for _, tc := range choice.Message.ToolCalls {
			var input interface{}
			if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil {
				return nil, fmt.Errorf("解析 tool_call arguments 失败: %w", err)
			}
			blocks = append(blocks, ContentBlock{
				Type:  "tool_use",
				ID:    tc.ID,
				Name:  tc.Function.Name,
				Input: input,
			})
		}
	}
	resp.Content = blocks
	// Map the OpenAI finish reason onto Anthropic's stop_reason;
	// unknown reasons leave it empty.
	switch choice.FinishReason {
	case "stop":
		resp.StopReason = "end_turn"
	case "tool_calls":
		resp.StopReason = "tool_use"
	case "length":
		resp.StopReason = "max_tokens"
	}
	return resp, nil
}
// convertMessage expands one Anthropic message into the equivalent
// OpenAI messages: text and tool_result blocks each become a separate
// entry; image blocks are rejected (MVP limitation).
func convertMessage(msg AnthropicMessage) ([]openai.Message, error) {
	var out []openai.Message
	for _, block := range msg.Content {
		switch block.Type {
		case "text":
			out = append(out, openai.Message{
				Role:    msg.Role,
				Content: block.Text,
			})
		case "tool_result":
			// String results pass through; anything else is re-encoded
			// as JSON.
			text, ok := block.Content.(string)
			if !ok {
				raw, err := json.Marshal(block.Content)
				if err != nil {
					return nil, fmt.Errorf("序列化 tool_result 内容失败: %w", err)
				}
				text = string(raw)
			}
			out = append(out, openai.Message{
				Role:       "tool",
				Content:    text,
				ToolCallID: block.ToolUseID,
			})
		case "image":
			return nil, fmt.Errorf("MVP 不支持多模态内容(图片)")
		default:
			return nil, fmt.Errorf("未知的内容块类型: %s", block.Type)
		}
	}
	// Guard against an empty content list: emit a blank message so the
	// role is still represented (should not normally happen).
	if len(out) == 0 {
		out = append(out, openai.Message{
			Role:    msg.Role,
			Content: "",
		})
	}
	return out, nil
}
// convertToolChoice maps an Anthropic tool_choice value — either the
// strings "auto"/"any" or an object like {type: "tool", name: ...} —
// onto the OpenAI representation.
func convertToolChoice(choice interface{}) (interface{}, error) {
	switch v := choice.(type) {
	case string:
		// Anthropic's "any" has no exact OpenAI analogue; treat both
		// accepted strings as "auto".
		if v != "auto" && v != "any" {
			return nil, fmt.Errorf("无效的 tool_choice 字符串: %s", v)
		}
		return "auto", nil
	case map[string]interface{}:
		kind, ok := v["type"].(string)
		if !ok {
			return nil, fmt.Errorf("tool_choice 对象缺少 type 字段")
		}
		switch kind {
		case "auto", "any":
			return "auto", nil
		case "tool":
			name, ok := v["name"].(string)
			if !ok {
				return nil, fmt.Errorf("tool_choice type=tool 缺少 name 字段")
			}
			// Forced selection of a specific function.
			return map[string]interface{}{
				"type": "function",
				"function": map[string]string{
					"name": name,
				},
			}, nil
		}
		return nil, fmt.Errorf("无效的 tool_choice type: %s", kind)
	}
	return nil, fmt.Errorf("tool_choice 格式无效")
}

View File

@@ -0,0 +1,164 @@
package anthropic
import (
"encoding/json"
"fmt"
"nex/backend/internal/protocol/openai"
)
// StreamConverter incrementally converts an OpenAI streaming response
// into the Anthropic streaming event sequence (message_start,
// content_block_*, message_delta, message_stop). One instance tracks the
// state of a single stream.
type StreamConverter struct {
	messageID      string         // Anthropic message ID echoed in message_start
	model          string         // model name echoed in message_start
	index          int            // current text content block index
	toolCallArgs   map[int]string // accumulated JSON arguments per tool_use block index
	sentStart      bool           // whether message_start has been emitted
	sentBlockStart map[int]bool   // per block index: whether content_block_start was emitted
}
// NewStreamConverter returns a converter bound to the given Anthropic
// message ID and model name, with all per-stream state reset.
func NewStreamConverter(messageID, model string) *StreamConverter {
	c := &StreamConverter{
		messageID:      messageID,
		model:          model,
		toolCallArgs:   map[int]string{},
		sentBlockStart: map[int]bool{},
	}
	// index and sentStart start at their zero values (0 / false).
	return c
}
// ConvertChunk converts one OpenAI stream chunk into zero or more
// Anthropic stream events.
//
// Sequencing: message_start is emitted exactly once; each content block
// (text or tool_use) gets a content_block_start before its deltas; a
// finish_reason closes every open block (content_block_stop, in index
// order) and then emits message_delta (with the mapped stop_reason) and
// message_stop.
func (c *StreamConverter) ConvertChunk(chunk *openai.StreamChunk) ([]StreamEvent, error) {
	var events []StreamEvent

	// Emit message_start exactly once, before any other event.
	if !c.sentStart {
		events = append(events, StreamEvent{
			Type: "message_start",
			Message: &MessagesResponse{
				ID:      c.messageID,
				Type:    "message",
				Role:    "assistant",
				Model:   c.model,
				Content: []ContentBlock{},
				Usage: Usage{
					InputTokens:  0,
					OutputTokens: 0,
				},
			},
		})
		c.sentStart = true
	}

	for _, choice := range chunk.Choices {
		// Text deltas go into the current text block (c.index).
		if choice.Delta.Content != "" {
			if !c.sentBlockStart[c.index] {
				events = append(events, StreamEvent{
					Type:  "content_block_start",
					Index: c.index,
					ContentBlock: &ContentBlock{
						Type: "text",
					},
				})
				c.sentBlockStart[c.index] = true
			}
			events = append(events, StreamEvent{
				Type:  "content_block_delta",
				Index: c.index,
				Delta: &Delta{
					Type: "text_delta",
					Text: choice.Delta.Content,
				},
			})
		}

		// Tool-call deltas. OpenAI streams a tool call as a first
		// fragment carrying the ID/name followed by ID-less fragments
		// carrying more of the arguments, so a non-empty ID opens a new
		// tool_use block and an empty ID continues the most recent one.
		// (Previously every fragment was given a fresh index, emitting
		// bogus content_block_start events for continuations.)
		for _, tc := range choice.Delta.ToolCalls {
			var toolIndex int
			if tc.ID != "" || len(c.toolCallArgs) == 0 {
				// New tool call: allocate the next unused block index —
				// one past every block opened so far, so it can never
				// collide with an open text block.
				toolIndex = len(c.sentBlockStart)
				events = append(events, StreamEvent{
					Type:  "content_block_start",
					Index: toolIndex,
					ContentBlock: &ContentBlock{
						Type: "tool_use",
						ID:   tc.ID,
						Name: tc.Function.Name,
					},
				})
				c.sentBlockStart[toolIndex] = true
				c.toolCallArgs[toolIndex] = ""
			} else {
				// Continuation fragment: append to the most recently
				// opened tool block (highest allocated tool index).
				for idx := range c.toolCallArgs {
					if idx > toolIndex {
						toolIndex = idx
					}
				}
			}
			c.toolCallArgs[toolIndex] += tc.Function.Arguments
			events = append(events, StreamEvent{
				Type:  "content_block_delta",
				Index: toolIndex,
				Delta: &Delta{
					Type:  "input_json_delta",
					Input: tc.Function.Arguments,
				},
			})
		}

		// A finish_reason closes every open block and ends the message.
		if choice.FinishReason != "" {
			// Block indexes are assigned contiguously from 0, so iterate
			// numerically for deterministic event order (ranging over the
			// map would be random).
			for idx := 0; idx < len(c.sentBlockStart); idx++ {
				if c.sentBlockStart[idx] {
					events = append(events, StreamEvent{
						Type:  "content_block_stop",
						Index: idx,
					})
				}
			}
			// Map OpenAI finish_reason to the Anthropic stop_reason.
			stopReason := ""
			switch choice.FinishReason {
			case "stop":
				stopReason = "end_turn"
			case "tool_calls":
				stopReason = "tool_use"
			case "length":
				stopReason = "max_tokens"
			}
			events = append(events, StreamEvent{
				Type: "message_delta",
				Delta: &Delta{
					StopReason: stopReason,
				},
			})
			events = append(events, StreamEvent{
				Type: "message_stop",
			})
		}
	}

	return events, nil
}
// SerializeEvent renders a stream event as one SSE frame:
// "event: <type>\ndata: <json>\n\n".
func SerializeEvent(event StreamEvent) (string, error) {
	payload, err := json.Marshal(event)
	if err != nil {
		return "", err
	}
	return "event: " + event.Type + "\ndata: " + string(payload) + "\n\n", nil
}

View File

@@ -0,0 +1,118 @@
package anthropic
import (
	"encoding/json"
	"fmt"
)
// MessagesRequest is the request body of the Anthropic Messages API.
type MessagesRequest struct {
	Model         string                 `json:"model"`
	Messages      []AnthropicMessage     `json:"messages"`
	System        string                 `json:"system,omitempty"`
	MaxTokens     int                    `json:"max_tokens"`
	Temperature   *float64               `json:"temperature,omitempty"`
	TopP          *float64               `json:"top_p,omitempty"`
	TopK          *int                   `json:"top_k,omitempty"`
	StopSequences []string               `json:"stop_sequences,omitempty"`
	Stream        bool                   `json:"stream,omitempty"`
	Tools         []AnthropicTool        `json:"tools,omitempty"`
	ToolChoice    interface{}            `json:"tool_choice,omitempty"` // string or object form
	Metadata      map[string]interface{} `json:"metadata,omitempty"`
}

// AnthropicMessage is a single conversation message.
type AnthropicMessage struct {
	Role    string         `json:"role"`
	Content []ContentBlock `json:"content"`
}

// ContentBlock is one typed content element of a message.
type ContentBlock struct {
	Type string      `json:"type"` // "text", "image", "tool_use", "tool_result"
	Text string      `json:"text,omitempty"`
	Input interface{} `json:"input,omitempty"` // tool_use: structured tool input
	// tool_use fields
	ID   string `json:"id,omitempty"`
	Name string `json:"name,omitempty"`
	// tool_result fields
	ToolUseID string      `json:"tool_use_id,omitempty"`
	Content   interface{} `json:"content,omitempty"` // string or array
	// multimodal field (not supported in the MVP)
	Source interface{} `json:"source,omitempty"` // image source payload
}

// AnthropicTool describes a tool the model may call.
type AnthropicTool struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	InputSchema map[string]interface{} `json:"input_schema"`
}

// ToolChoice is the object form of the tool_choice field.
type ToolChoice struct {
	Type string `json:"type"`           // "auto", "any", "tool"
	Name string `json:"name,omitempty"` // used when Type == "tool"
}
// MessagesResponse is the response body of the Anthropic Messages API.
type MessagesResponse struct {
	ID           string         `json:"id"`
	Type         string         `json:"type"` // always "message"
	Role         string         `json:"role"` // always "assistant"
	Content      []ContentBlock `json:"content"`
	Model        string         `json:"model"`
	StopReason   string         `json:"stop_reason,omitempty"` // "end_turn", "max_tokens", "stop_sequence", "tool_use"
	StopSequence string         `json:"stop_sequence,omitempty"`
	Usage        Usage          `json:"usage"`
}

// Usage reports token consumption for a request.
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// StreamEvent is one event of the Anthropic streaming protocol.
type StreamEvent struct {
	Type         string            `json:"type"`
	Message      *MessagesResponse `json:"message,omitempty"`       // message_start only
	Index        int               `json:"index,omitempty"`         // content_block_* events
	ContentBlock *ContentBlock     `json:"content_block,omitempty"` // content_block_start only
	Delta        *Delta            `json:"delta,omitempty"`         // content_block_delta / message_delta
}

// Delta is the incremental payload of a streaming event.
type Delta struct {
	Type       string `json:"type,omitempty"` // "text_delta" or "input_json_delta"
	Text       string `json:"text,omitempty"`
	Input      string `json:"input,omitempty"`       // partial JSON of a tool_use input
	StopReason string `json:"stop_reason,omitempty"` // message_delta only
	Usage      *Usage `json:"usage,omitempty"`       // message_delta only
}

// ErrorResponse is the Anthropic error envelope.
type ErrorResponse struct {
	Type  string      `json:"type"` // always "error"
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the error classification and message.
type ErrorDetail struct {
	Type    string `json:"type"` // e.g. "invalid_request_error", "authentication_error"
	Message string `json:"message"`
}
// ParseInputJSON returns the tool_use input as a map. A JSON string is
// unmarshaled; an already-decoded object is returned as-is; anything
// else (including nil) yields a descriptive error.
//
// Previously the fallback returned json.Unmarshal([]byte{}, nil), whose
// InvalidUnmarshalError message had nothing to do with the actual
// problem.
func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) {
	switch v := cb.Input.(type) {
	case string:
		var result map[string]interface{}
		if err := json.Unmarshal([]byte(v), &result); err != nil {
			return nil, err
		}
		return result, nil
	case map[string]interface{}:
		// Already decoded by an earlier json.Unmarshal of the request.
		return v, nil
	default:
		return nil, fmt.Errorf("无法解析 tool_use input: 不支持的类型 %T", cb.Input)
	}
}

View File

@@ -0,0 +1,86 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
)
// Adapter is the OpenAI protocol adapter; requests and responses are
// passed through unmodified.
type Adapter struct{}

// NewAdapter constructs an OpenAI pass-through adapter.
func NewAdapter() *Adapter {
	return new(Adapter)
}
// PrepareRequest builds the outgoing HTTP request for an OpenAI-compatible
// provider. The body is the serialized ChatCompletionRequest; baseURL must
// already include the version path (e.g. /v1 or /v4), so only the endpoint
// path is appended.
func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL string) (*http.Request, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	// SECURITY: the request body (which contains user conversation data)
	// is deliberately NOT logged here; the previous debug Printf dumped
	// every prompt to stdout.
	httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	return httpReq, nil
}
// ParseResponse decodes a successful provider response body into a
// ChatCompletionResponse. The body is always closed, including when the
// read itself fails (previously the Close was registered only after a
// successful ReadAll, leaking the body on read errors).
func (a *Adapter) ParseResponse(resp *http.Response) (*ChatCompletionResponse, error) {
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var result ChatCompletionResponse
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, err
	}
	return &result, nil
}
// ParseErrorResponse decodes a provider error body into an ErrorResponse.
// The body is always closed, including when the read itself fails
// (previously the Close was registered only after a successful ReadAll).
func (a *Adapter) ParseErrorResponse(resp *http.Response) (*ErrorResponse, error) {
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var result ErrorResponse
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, err
	}
	return &result, nil
}
// ParseStreamChunk decodes one SSE data payload into a StreamChunk.
func (a *Adapter) ParseStreamChunk(data []byte) (*StreamChunk, error) {
	chunk := new(StreamChunk)
	if err := json.Unmarshal(data, chunk); err != nil {
		return nil, err
	}
	return chunk, nil
}

View File

@@ -0,0 +1,131 @@
package openai
import "encoding/json"
// ChatCompletionRequest is the request body of the OpenAI Chat
// Completions API.
type ChatCompletionRequest struct {
	Model            string      `json:"model"`
	Messages         []Message   `json:"messages"`
	Temperature      *float64    `json:"temperature,omitempty"`
	MaxTokens        *int        `json:"max_tokens,omitempty"`
	TopP             *float64    `json:"top_p,omitempty"`
	FrequencyPenalty *float64    `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64    `json:"presence_penalty,omitempty"`
	Stop             interface{} `json:"stop,omitempty"` // string or []string
	N                *int        `json:"n,omitempty"`
	Stream           bool        `json:"stream,omitempty"`
	Tools            []Tool      `json:"tools,omitempty"`
	ToolChoice       interface{} `json:"tool_choice,omitempty"` // string or object form
	User             string      `json:"user,omitempty"`
}

// Message is one OpenAI chat message.
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"` // string, or array for multimodal (not supported in the MVP)
	Name       string      `json:"name,omitempty"`
	ToolCalls  []ToolCall  `json:"tool_calls,omitempty"`
	ToolCallID string      `json:"tool_call_id,omitempty"` // set on role="tool" messages
}

// Tool declares a callable tool.
type Tool struct {
	Type     string             `json:"type"` // currently only "function"
	Function FunctionDefinition `json:"function"`
}

// FunctionDefinition describes a callable function and its JSON-schema
// parameters.
type FunctionDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}

// ToolCall is a model-initiated tool invocation.
type ToolCall struct {
	ID       string       `json:"id"`
	Type     string       `json:"type"` // "function"
	Function FunctionCall `json:"function"`
}

// FunctionCall names the function being called and carries its arguments.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON-encoded argument object
}
// ChatCompletionResponse is the response body of the OpenAI Chat
// Completions API.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// Choice is one generated completion candidate.
type Choice struct {
	Index        int      `json:"index"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Delta   `json:"delta,omitempty"` // streaming responses only
	FinishReason string   `json:"finish_reason"`
}

// Delta is the incremental payload of a streaming choice.
type Delta struct {
	Role      string     `json:"role,omitempty"`
	Content   string     `json:"content,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Usage reports token consumption for a request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// StreamChunk is one SSE data payload of a streaming response.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// StreamChoice is one choice inside a stream chunk.
type StreamChoice struct {
	Index        int    `json:"index"`
	Delta        Delta  `json:"delta"`
	FinishReason string `json:"finish_reason,omitempty"`
}

// ErrorResponse is the OpenAI error envelope.
type ErrorResponse struct {
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the error message and classification.
type ErrorDetail struct {
	Message string `json:"message"`
	Type    string `json:"type,omitempty"`
	Code    string `json:"code,omitempty"`
}
// ParseToolCallArguments decodes the JSON arguments string of a tool
// call into a map.
func (tc *ToolCall) ParseToolCallArguments() (map[string]interface{}, error) {
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// SerializeToolCallArguments encodes a tool-call argument map as a JSON
// string.
func SerializeToolCallArguments(args map[string]interface{}) (string, error) {
	encoded, err := json.Marshal(args)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}

View File

@@ -0,0 +1,177 @@
package provider
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"nex/backend/internal/protocol/openai"
)
// Client calls OpenAI-compatible providers over HTTP.
type Client struct {
	httpClient *http.Client    // shared client; its Timeout bounds the entire request
	adapter    *openai.Adapter // protocol adapter (pass-through)
}

// NewClient constructs a provider client whose HTTP client has a
// 30-second total timeout, sized for non-streaming requests.
func NewClient() *Client {
	return &Client{
		httpClient: &http.Client{
			Timeout: 30 * time.Second, // non-streaming request timeout
		},
		adapter: openai.NewAdapter(),
	}
}
// SendRequest performs a non-streaming chat completion request against a
// provider and returns the parsed response.
//
// Non-200 statuses become errors: when the provider returns a parseable
// OpenAI error envelope its message is surfaced, otherwise the raw HTTP
// status code is reported.
func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
	httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
	if err != nil {
		return nil, fmt.Errorf("准备请求失败: %w", err)
	}

	// SECURITY: do not log the request headers — the Authorization header
	// carries the provider API key (the previous debug Printf leaked it
	// to stdout).
	httpReq = httpReq.WithContext(ctx)

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the provider's own error message when it is parseable.
		errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
		if parseErr != nil {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
		}
		return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
	}

	result, err := c.adapter.ParseResponse(resp)
	if err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return result, nil
}
// SendStreamRequest performs a streaming chat completion request and
// returns a channel of raw SSE events. The channel is closed when the
// stream ends, fails, or ctx is cancelled.
func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) {
	// Force streaming mode on the upstream request.
	req.Stream = true

	httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
	if err != nil {
		return nil, fmt.Errorf("准备请求失败: %w", err)
	}

	httpReq = httpReq.WithContext(ctx)

	// Use a client WITHOUT the shared 30s total timeout: a healthy SSE
	// stream can legitimately run longer than any fixed request timeout
	// (the old code cut every stream off at 30 seconds). Cancellation is
	// handled via ctx instead; the transport is shared for connection reuse.
	streamClient := &http.Client{Transport: c.httpClient.Transport}
	resp, err := streamClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Prefer the provider's own error message when it is parseable.
		errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
		if parseErr != nil {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
		}
		return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
	}

	// Buffered so a slow consumer does not immediately block the reader.
	eventChan := make(chan StreamEvent, 100)
	go c.readStream(ctx, resp.Body, eventChan)
	return eventChan, nil
}
// StreamEvent is one raw event read from a provider SSE stream: a data
// payload, a terminal error, or the [DONE] marker.
type StreamEvent struct {
	Data  []byte // raw JSON payload of one "data:" line
	Error error  // non-nil when reading the stream failed
	Done  bool   // true when the provider sent the "[DONE]" sentinel
}
// readStream reads the provider's SSE byte stream, extracts "data:"
// payloads, and forwards them on eventChan. It closes both the body and
// the channel on return.
//
// Events are delimited by a blank line; both LF and CRLF framings are
// accepted (the old code only matched "\n\n" and so never split
// CRLF-framed streams). Bytes returned together with a Read error are
// processed before the error is handled, so a final event arriving with
// EOF is not dropped. The "[DONE]" sentinel ends the stream.
func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) {
	defer close(eventChan)
	defer body.Close()

	buf := make([]byte, 4096)
	var dataBuf []byte

	for {
		// Bail out promptly if the caller cancelled. (A Read blocked on
		// the network also ends when cancellation closes the connection.)
		select {
		case <-ctx.Done():
			eventChan <- StreamEvent{Error: ctx.Err()}
			return
		default:
		}

		n, err := body.Read(buf)
		if n > 0 {
			dataBuf = append(dataBuf, buf[:n]...)

			// Drain every complete SSE event (terminated by a blank line)
			// currently buffered.
			for {
				idx := bytes.Index(dataBuf, []byte("\n\n"))
				sep := 2
				if cr := bytes.Index(dataBuf, []byte("\r\n\r\n")); cr != -1 && (idx == -1 || cr < idx) {
					idx, sep = cr, 4
				}
				if idx == -1 {
					break
				}
				event := dataBuf[:idx]
				dataBuf = dataBuf[idx+sep:]

				for _, line := range strings.Split(string(event), "\n") {
					// Tolerate CRLF line endings inside an event.
					line = strings.TrimSuffix(line, "\r")
					if !strings.HasPrefix(line, "data: ") {
						continue
					}
					data := strings.TrimPrefix(line, "data: ")
					if data == "[DONE]" {
						eventChan <- StreamEvent{Done: true}
						return
					}
					eventChan <- StreamEvent{Data: []byte(data)}
				}
			}
		}
		if err != nil {
			if err == io.EOF {
				return // stream finished normally
			}
			eventChan <- StreamEvent{Error: err}
			return
		}
	}
}

View File

@@ -0,0 +1,71 @@
package router
import (
"errors"
"fmt"
"nex/backend/internal/config"
)
// Routing errors returned by Router.Route. Callers compare with
// errors.Is; the messages are user-facing and intentionally unchanged.
var (
	ErrModelNotFound    = errors.New("模型未找到")
	ErrModelDisabled    = errors.New("模型已禁用")
	ErrProviderDisabled = errors.New("供应商已禁用")
)
// RouteResult is the provider/model pair selected for a request.
type RouteResult struct {
	Provider *config.Provider
	Model    *config.Model
}
// Router resolves an incoming model name to a configured provider.
type Router struct{}

// NewRouter constructs a Router.
func NewRouter() *Router {
	return new(Router)
}
// Route resolves modelName to its enabled model and enabled provider.
// It returns ErrModelNotFound, ErrModelDisabled, or ErrProviderDisabled
// when the lookup fails at the corresponding stage.
func (r *Router) Route(modelName string) (*RouteResult, error) {
	// Load all configured models and scan for a name match.
	models, err := config.ListModels("")
	if err != nil {
		return nil, fmt.Errorf("查询模型失败: %w", err)
	}

	var match *config.Model
	for i := range models {
		if models[i].ModelName != modelName {
			continue
		}
		match = &models[i]
		break
	}
	if match == nil {
		return nil, ErrModelNotFound
	}
	if !match.Enabled {
		return nil, ErrModelDisabled
	}

	// Resolve and validate the owning provider.
	provider, err := config.GetProvider(match.ProviderID, false)
	if err != nil {
		return nil, fmt.Errorf("查询供应商失败: %w", err)
	}
	if !provider.Enabled {
		return nil, ErrProviderDisabled
	}

	return &RouteResult{
		Provider: provider,
		Model:    match,
	}, nil
}