1
0

feat: 初始化 AI Gateway 项目

实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含:
- OpenAI 和 Anthropic 协议代理
- 供应商和模型管理
- 用量统计
- 前端配置界面
This commit is contained in:
2026-04-15 16:53:28 +08:00
commit 915b004924
53 changed files with 5662 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
package config
import (
"os"
"path/filepath"
)
// GetConfigDir 获取配置目录路径(~/.nex/
func GetConfigDir() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
// 确保目录存在
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
}
return configDir, nil
}
// GetDBPath 获取数据库文件路径
func GetDBPath() (string, error) {
configDir, err := GetConfigDir()
if err != nil {
return "", err
}
return filepath.Join(configDir, "config.db"), nil
}

View File

@@ -0,0 +1,58 @@
package config
import (
"fmt"
"log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var db *gorm.DB
// InitDB 初始化数据库连接并创建表
func InitDB() error {
dbPath, err := GetDBPath()
if err != nil {
return fmt.Errorf("获取数据库路径失败: %w", err)
}
// 打开数据库连接
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
// 启用 WAL 模式以提升并发性能
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
// 自动迁移表结构
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
return fmt.Errorf("创建表失败: %w", err)
}
log.Printf("数据库初始化成功: %s", dbPath)
return nil
}
// GetDB 获取数据库连接
func GetDB() *gorm.DB {
return db
}
// CloseDB 关闭数据库连接
func CloseDB() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}

View File

@@ -0,0 +1,119 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateModel 创建模型
func CreateModel(model *Model) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 验证供应商是否存在
var provider Provider
err := db.First(&provider, "id = ?", model.ProviderID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
model.CreatedAt = time.Now()
return db.Create(model).Error
}
// GetModel 获取模型
func GetModel(id string) (*Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var model Model
err := db.First(&model, "id = ?", id).Error
if err != nil {
return nil, err
}
return &model, nil
}
// ListModels 列出模型
func ListModels(providerID string) ([]Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var models []Model
var err error
if providerID != "" {
err = db.Where("provider_id = ?", providerID).Find(&models).Error
} else {
err = db.Find(&models).Error
}
if err != nil {
return nil, err
}
return models, nil
}
// UpdateModel 更新模型
func UpdateModel(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 如果更新了 provider_id验证新供应商是否存在
if providerID, ok := updates["provider_id"].(string); ok {
var provider Provider
err := db.First(&provider, "id = ?", providerID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
}
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteModel 删除模型
func DeleteModel(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Model{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -0,0 +1,57 @@
package config
import (
"time"
)
// Provider 供应商模型
type Provider struct {
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}
// Model 模型配置
type Model struct {
ID string `gorm:"primaryKey" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}
// UsageStats 用量统计
type UsageStats struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
RequestCount int `gorm:"default:0" json:"request_count"`
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
}
// TableName 指定表名
func (Provider) TableName() string {
return "providers"
}
func (Model) TableName() string {
return "models"
}
func (UsageStats) TableName() string {
return "usage_stats"
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -0,0 +1,102 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateProvider 创建供应商
func CreateProvider(provider *Provider) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
provider.CreatedAt = time.Now()
provider.UpdatedAt = time.Now()
return db.Create(provider).Error
}
// GetProvider 获取供应商
func GetProvider(id string, maskKey bool) (*Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var provider Provider
err := db.First(&provider, "id = ?", id).Error
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return &provider, nil
}
// ListProviders 列出所有供应商
func ListProviders() ([]Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var providers []Provider
err := db.Find(&providers).Error
if err != nil {
return nil, err
}
// 掩码所有 API Key
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
}
// UpdateProvider 更新供应商
func UpdateProvider(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
updates["updated_at"] = time.Now()
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteProvider 删除供应商
func DeleteProvider(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Provider{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -0,0 +1,79 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// RecordRequest 记录请求统计
func RecordRequest(providerID, modelName string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
// 使用事务确保并发安全
return db.Transaction(func(tx *gorm.DB) error {
var stats UsageStats
// 查找或创建统计记录
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).
First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// 创建新记录
stats = UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
// 更新计数
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
}
// GetStats 查询统计
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var stats []UsageStats
query := db.Model(&UsageStats{})
if providerID != "" {
query = query.Where("provider_id = ?", providerID)
}
if modelName != "" {
query = query.Where("model_name = ?", modelName)
}
if startDate != nil {
query = query.Where("date >= ?", startDate)
}
if endDate != nil {
query = query.Where("date <= ?", endDate)
}
err := query.Order("date DESC").Find(&stats).Error
if err != nil {
return nil, err
}
return stats, nil
}