feat: 初始化 AI Gateway 项目
实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含: - OpenAI 和 Anthropic 协议代理 - 供应商和模型管理 - 用量统计 - 前端配置界面
This commit is contained in:
32
backend/internal/config/config.go
Normal file
32
backend/internal/config/config.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// GetConfigDir 获取配置目录路径(~/.nex/)
|
||||
func GetConfigDir() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
// GetDBPath 获取数据库文件路径
|
||||
func GetDBPath() (string, error) {
|
||||
configDir, err := GetConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configDir, "config.db"), nil
|
||||
}
|
||||
58
backend/internal/config/database.go
Normal file
58
backend/internal/config/database.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
// InitDB 初始化数据库连接并创建表
|
||||
func InitDB() error {
|
||||
dbPath, err := GetDBPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取数据库路径失败: %w", err)
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 启用 WAL 模式以提升并发性能
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移表结构
|
||||
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
|
||||
return fmt.Errorf("创建表失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("数据库初始化成功: %s", dbPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取数据库连接
|
||||
func GetDB() *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// CloseDB 关闭数据库连接
|
||||
func CloseDB() error {
|
||||
if db != nil {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
119
backend/internal/config/model.go
Normal file
119
backend/internal/config/model.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateModel 创建模型
|
||||
func CreateModel(model *Model) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 验证供应商是否存在
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", model.ProviderID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
model.CreatedAt = time.Now()
|
||||
|
||||
return db.Create(model).Error
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
func GetModel(id string) (*Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var model Model
|
||||
err := db.First(&model, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
func ListModels(providerID string) ([]Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var models []Model
|
||||
var err error
|
||||
|
||||
if providerID != "" {
|
||||
err = db.Where("provider_id = ?", providerID).Find(&models).Error
|
||||
} else {
|
||||
err = db.Find(&models).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func UpdateModel(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 如果更新了 provider_id,验证新供应商是否存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", providerID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func DeleteModel(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Model{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
57
backend/internal/config/models.go
Normal file
57
backend/internal/config/models.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置
|
||||
type Model struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Provider) TableName() string {
|
||||
return "providers"
|
||||
}
|
||||
|
||||
func (Model) TableName() string {
|
||||
return "models"
|
||||
}
|
||||
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
102
backend/internal/config/provider.go
Normal file
102
backend/internal/config/provider.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
func CreateProvider(provider *Provider) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
provider.CreatedAt = time.Now()
|
||||
provider.UpdatedAt = time.Now()
|
||||
|
||||
return db.Create(provider).Error
|
||||
}
|
||||
|
||||
// GetProvider 获取供应商
|
||||
func GetProvider(id string, maskKey bool) (*Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func ListProviders() ([]Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var providers []Provider
|
||||
err := db.Find(&providers).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 掩码所有 API Key
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// UpdateProvider 更新供应商
|
||||
func UpdateProvider(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
updates["updated_at"] = time.Now()
|
||||
|
||||
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProvider 删除供应商
|
||||
func DeleteProvider(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Provider{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
backend/internal/config/stats.go
Normal file
79
backend/internal/config/stats.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RecordRequest 记录请求统计
|
||||
func RecordRequest(providerID, modelName string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
|
||||
// 使用事务确保并发安全
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats UsageStats
|
||||
|
||||
// 查找或创建统计记录
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).
|
||||
First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 创建新记录
|
||||
stats = UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新计数
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var stats []UsageStats
|
||||
query := db.Model(&UsageStats{})
|
||||
|
||||
if providerID != "" {
|
||||
query = query.Where("provider_id = ?", providerID)
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
query = query.Where("model_name = ?", modelName)
|
||||
}
|
||||
|
||||
if startDate != nil {
|
||||
query = query.Where("date >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate != nil {
|
||||
query = query.Where("date <= ?", endDate)
|
||||
}
|
||||
|
||||
err := query.Order("date DESC").Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
Reference in New Issue
Block a user