feat: 初始化 AI Gateway 项目
实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含: - OpenAI 和 Anthropic 协议代理 - 供应商和模型管理 - 用量统计 - 前端配置界面
This commit is contained in:
188
backend/README.md
Normal file
188
backend/README.md
Normal file
@@ -0,0 +1,188 @@
|
||||
# AI Gateway Backend
|
||||
|
||||
AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/v1/messages`)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── cmd/
|
||||
│ └── server/
|
||||
│ └── main.go # 主程序入口
|
||||
├── internal/
|
||||
│ ├── config/ # 配置和数据库
|
||||
│ │ ├── config.go # 配置目录管理
|
||||
│ │ ├── database.go # 数据库连接
|
||||
│ │ ├── models.go # 数据模型
|
||||
│ │ ├── provider.go # 供应商 CRUD
|
||||
│ │ ├── model.go # 模型 CRUD
|
||||
│ │ └── stats.go # 统计记录
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── openai_handler.go
|
||||
│ │ ├── anthropic_handler.go
|
||||
│ │ ├── provider_handler.go
|
||||
│ │ ├── model_handler.go
|
||||
│ │ └── stats_handler.go
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ │ ├── openai/
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ └── adapter.go
|
||||
│ │ └── anthropic/
|
||||
│ │ ├── types.go
|
||||
│ │ ├── converter.go
|
||||
│ │ └── stream_converter.go
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ │ └── client.go
|
||||
│ └── router/ # 模型路由
|
||||
│ └── model_router.go
|
||||
├── go.mod
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
go mod download
|
||||
```
|
||||
|
||||
### 启动服务
|
||||
|
||||
```bash
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
服务将在端口 9826 启动。
|
||||
|
||||
## API 文档
|
||||
|
||||
### 代理接口
|
||||
|
||||
#### OpenAI Chat Completions
|
||||
|
||||
```
|
||||
POST /v1/chat/completions
|
||||
```
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
#### Anthropic Messages
|
||||
|
||||
```
|
||||
POST /v1/messages
|
||||
```
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 管理接口
|
||||
|
||||
#### 供应商管理
|
||||
|
||||
- `GET /api/providers` - 列出所有供应商
|
||||
- `POST /api/providers` - 创建供应商
|
||||
- `GET /api/providers/:id` - 获取供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
创建供应商示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "openai",
|
||||
"name": "OpenAI",
|
||||
"api_key": "sk-...",
|
||||
"base_url": "https://api.openai.com/v1"
|
||||
}
|
||||
```
|
||||
|
||||
**重要说明:**
|
||||
- `base_url` 应配置到 API 版本路径,不包含具体端点
|
||||
- OpenAI: `https://api.openai.com/v1`
|
||||
- GLM: `https://open.bigmodel.cn/api/paas/v4`
|
||||
- 其他 OpenAI 兼容供应商根据其文档配置版本路径
|
||||
|
||||
#### 模型管理
|
||||
|
||||
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
||||
- `POST /api/models` - 创建模型
|
||||
- `GET /api/models/:id` - 获取模型
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
创建模型示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4"
|
||||
}
|
||||
```
|
||||
|
||||
#### 统计查询
|
||||
|
||||
- `GET /api/stats` - 查询统计
|
||||
- `GET /api/stats/aggregate` - 聚合统计
|
||||
|
||||
查询参数:
|
||||
|
||||
- `provider_id` - 供应商 ID
|
||||
- `model_name` - 模型名称
|
||||
- `start_date` - 开始日期(YYYY-MM-DD)
|
||||
- `end_date` - 结束日期(YYYY-MM-DD)
|
||||
- `group_by` - 聚合维度(provider/model/date)
|
||||
|
||||
## 配置
|
||||
|
||||
配置和数据存储在 `~/.nex/` 目录:
|
||||
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
|
||||
## 开发
|
||||
|
||||
### 构建
|
||||
|
||||
```bash
|
||||
go build -o ai-gateway cmd/server/main.go
|
||||
```
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Go 1.21 或更高版本
|
||||
119
backend/cmd/server/main.go
Normal file
119
backend/cmd/server/main.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/handler"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 初始化数据库
|
||||
if err := config.InitDB(); err != nil {
|
||||
log.Fatalf("初始化数据库失败: %v", err)
|
||||
}
|
||||
defer config.CloseDB()
|
||||
|
||||
// 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.Default()
|
||||
|
||||
// 配置 CORS
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
// 注册路由
|
||||
setupRoutes(r)
|
||||
|
||||
// 创建 HTTP 服务器
|
||||
srv := &http.Server{
|
||||
Addr: ":9826",
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
// 启动服务器(在 goroutine 中)
|
||||
go func() {
|
||||
log.Printf("AI Gateway 启动在端口 9826")
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号以优雅关闭服务器
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭服务器...")
|
||||
|
||||
// 给服务器 5 秒时间完成当前请求
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Fatal("服务器强制关闭:", err)
|
||||
}
|
||||
|
||||
log.Println("服务器已关闭")
|
||||
}
|
||||
|
||||
// setupRoutes 配置路由
|
||||
func setupRoutes(r *gin.Engine) {
|
||||
// OpenAI 协议代理
|
||||
openaiHandler := handler.NewOpenAIHandler()
|
||||
r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions)
|
||||
|
||||
// Anthropic 协议代理
|
||||
anthropicHandler := handler.NewAnthropicHandler()
|
||||
r.POST("/v1/messages", anthropicHandler.HandleMessages)
|
||||
|
||||
// 供应商管理 API
|
||||
providerHandler := handler.NewProviderHandler()
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
// 模型管理 API
|
||||
modelHandler := handler.NewModelHandler()
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
// 统计查询 API
|
||||
statsHandler := handler.NewStatsHandler()
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
41
backend/go.mod
Normal file
41
backend/go.mod
Normal file
@@ -0,0 +1,41 @@
|
||||
module nex/backend
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gorm.io/driver/sqlite v1.6.0 // indirect
|
||||
gorm.io/gorm v1.31.1 // indirect
|
||||
)
|
||||
88
backend/go.sum
Normal file
88
backend/go.sum
Normal file
@@ -0,0 +1,88 @@
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
32
backend/internal/config/config.go
Normal file
32
backend/internal/config/config.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// GetConfigDir 获取配置目录路径(~/.nex/)
|
||||
func GetConfigDir() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
// GetDBPath 获取数据库文件路径
|
||||
func GetDBPath() (string, error) {
|
||||
configDir, err := GetConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configDir, "config.db"), nil
|
||||
}
|
||||
58
backend/internal/config/database.go
Normal file
58
backend/internal/config/database.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
// InitDB 初始化数据库连接并创建表
|
||||
func InitDB() error {
|
||||
dbPath, err := GetDBPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取数据库路径失败: %w", err)
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 启用 WAL 模式以提升并发性能
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移表结构
|
||||
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
|
||||
return fmt.Errorf("创建表失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("数据库初始化成功: %s", dbPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取数据库连接
|
||||
func GetDB() *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// CloseDB 关闭数据库连接
|
||||
func CloseDB() error {
|
||||
if db != nil {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
119
backend/internal/config/model.go
Normal file
119
backend/internal/config/model.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateModel 创建模型
|
||||
func CreateModel(model *Model) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 验证供应商是否存在
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", model.ProviderID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
model.CreatedAt = time.Now()
|
||||
|
||||
return db.Create(model).Error
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
func GetModel(id string) (*Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var model Model
|
||||
err := db.First(&model, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
func ListModels(providerID string) ([]Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var models []Model
|
||||
var err error
|
||||
|
||||
if providerID != "" {
|
||||
err = db.Where("provider_id = ?", providerID).Find(&models).Error
|
||||
} else {
|
||||
err = db.Find(&models).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func UpdateModel(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 如果更新了 provider_id,验证新供应商是否存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", providerID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func DeleteModel(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Model{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
57
backend/internal/config/models.go
Normal file
57
backend/internal/config/models.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置
|
||||
type Model struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Provider) TableName() string {
|
||||
return "providers"
|
||||
}
|
||||
|
||||
func (Model) TableName() string {
|
||||
return "models"
|
||||
}
|
||||
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
102
backend/internal/config/provider.go
Normal file
102
backend/internal/config/provider.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
func CreateProvider(provider *Provider) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
provider.CreatedAt = time.Now()
|
||||
provider.UpdatedAt = time.Now()
|
||||
|
||||
return db.Create(provider).Error
|
||||
}
|
||||
|
||||
// GetProvider 获取供应商
|
||||
func GetProvider(id string, maskKey bool) (*Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func ListProviders() ([]Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var providers []Provider
|
||||
err := db.Find(&providers).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 掩码所有 API Key
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// UpdateProvider 更新供应商
|
||||
func UpdateProvider(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
updates["updated_at"] = time.Now()
|
||||
|
||||
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProvider 删除供应商
|
||||
func DeleteProvider(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Provider{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
backend/internal/config/stats.go
Normal file
79
backend/internal/config/stats.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RecordRequest 记录请求统计
|
||||
func RecordRequest(providerID, modelName string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
|
||||
// 使用事务确保并发安全
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats UsageStats
|
||||
|
||||
// 查找或创建统计记录
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).
|
||||
First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 创建新记录
|
||||
stats = UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新计数
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var stats []UsageStats
|
||||
query := db.Model(&UsageStats{})
|
||||
|
||||
if providerID != "" {
|
||||
query = query.Where("provider_id = ?", providerID)
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
query = query.Where("model_name = ?", modelName)
|
||||
}
|
||||
|
||||
if startDate != nil {
|
||||
query = query.Where("date >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate != nil {
|
||||
query = query.Where("date <= ?", endDate)
|
||||
}
|
||||
|
||||
err := query.Order("date DESC").Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
243
backend/internal/handler/anthropic_handler.go
Normal file
243
backend/internal/handler/anthropic_handler.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler() *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
// 解析 Anthropic 请求
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查多模态内容
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 OpenAI 请求
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "请求转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, openaiReq, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 响应
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "响应转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 创建流式转换器
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 OpenAI 流块
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
fmt.Printf("解析流块失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 事件
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
fmt.Printf("转换事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入事件
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
fmt.Printf("序列化事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// checkMultimodalContent 检查多模态内容
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "image" {
|
||||
return fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型未找到",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型已禁用",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "供应商已禁用",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
161
backend/internal/handler/model_handler.go
Normal file
161
backend/internal/handler/model_handler.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
type ModelHandler struct{}
|
||||
|
||||
// NewModelHandler 创建模型处理器
|
||||
func NewModelHandler() *ModelHandler {
|
||||
return &ModelHandler{}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
ProviderID string `json:"provider_id" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, provider_id, model_name",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建模型对象
|
||||
model := &config.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateModel(model)
|
||||
if err != nil {
|
||||
if err.Error() == "供应商不存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model)
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
providerID := c.Query("provider_id")
|
||||
|
||||
models, err := config.ListModels(providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
model, err := config.GetModel(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的请求格式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新模型
|
||||
err := config.UpdateModel(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err.Error() == "供应商不存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的模型
|
||||
model, err := config.GetModel(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := config.DeleteModel(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除模型失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
167
backend/internal/handler/openai_handler.go
Normal file
167
backend/internal/handler/openai_handler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler() *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
// 解析请求
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, &req, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
// 流错误,记录日志
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
// 流结束
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 写入事件数据
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型未找到",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_not_found",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_disabled",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "provider_disabled",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
167
backend/internal/handler/provider_handler.go
Normal file
167
backend/internal/handler/provider_handler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
type ProviderHandler struct{}
|
||||
|
||||
// NewProviderHandler 创建供应商处理器
|
||||
func NewProviderHandler() *ProviderHandler {
|
||||
return &ProviderHandler{}
|
||||
}
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, name, api_key, base_url",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建供应商对象
|
||||
provider := &config.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateProvider(provider)
|
||||
if err != nil {
|
||||
// 检查是否是唯一约束错误(ID 重复)
|
||||
if err.Error() == "UNIQUE constraint failed: providers.id" {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 掩码 API Key 后返回
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
providers, err := config.ListProviders()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, providers)
|
||||
}
|
||||
|
||||
// GetProvider 获取供应商
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := config.GetProvider(id, true) // 掩码 API Key
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// UpdateProvider 更新供应商
|
||||
func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的请求格式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新供应商
|
||||
err := config.UpdateProvider(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的供应商
|
||||
provider, err := config.GetProvider(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, provider)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除供应商
|
||||
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 删除供应商(级联删除模型)
|
||||
err := config.DeleteProvider(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除供应商失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除关联的模型
|
||||
models, _ := config.ListModels("")
|
||||
for _, model := range models {
|
||||
if model.ProviderID == id {
|
||||
_ = config.DeleteModel(model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
184
backend/internal/handler/stats_handler.go
Normal file
184
backend/internal/handler/stats_handler.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
type StatsHandler struct{}
|
||||
|
||||
// NewStatsHandler 创建统计处理器
|
||||
func NewStatsHandler() *StatsHandler {
|
||||
return &StatsHandler{}
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
startDate = &t
|
||||
}
|
||||
|
||||
if endDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// AggregateStats 聚合统计
|
||||
func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
groupBy := c.Query("group_by") // "provider", "model", "date"
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
startDate = &t
|
||||
}
|
||||
|
||||
if endDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
|
||||
})
|
||||
return
|
||||
}
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合
|
||||
result := h.aggregate(stats, groupBy)
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// aggregate 执行聚合
|
||||
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
|
||||
switch groupBy {
|
||||
case "provider":
|
||||
return h.aggregateByProvider(stats)
|
||||
case "model":
|
||||
return h.aggregateByModel(stats)
|
||||
case "date":
|
||||
return h.aggregateByDate(stats)
|
||||
default:
|
||||
// 默认按供应商聚合
|
||||
return h.aggregateByProvider(stats)
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateByProvider 按供应商聚合
|
||||
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
aggregated[stat.ProviderID] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for providerID, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": providerID,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByModel 按模型聚合
|
||||
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.ProviderID + "/" + stat.ModelName
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for key, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": key[:len(key)/2],
|
||||
"model_name": key[len(key)/2+1:],
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByDate 按日期聚合
|
||||
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.Date.Format("2006-01-02")
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for date, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
234
backend/internal/protocol/anthropic/converter.go
Normal file
234
backend/internal/protocol/anthropic/converter.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// ConvertRequest 将 Anthropic 请求转换为 OpenAI 请求
|
||||
func ConvertRequest(anthropicReq *MessagesRequest) (*openai.ChatCompletionRequest, error) {
|
||||
openaiReq := &openai.ChatCompletionRequest{
|
||||
Model: anthropicReq.Model,
|
||||
Temperature: anthropicReq.Temperature,
|
||||
TopP: anthropicReq.TopP,
|
||||
Stream: anthropicReq.Stream,
|
||||
}
|
||||
|
||||
// 处理 max_tokens(Anthropic 要求必须有,默认 4096)
|
||||
if anthropicReq.MaxTokens > 0 {
|
||||
openaiReq.MaxTokens = &anthropicReq.MaxTokens
|
||||
} else {
|
||||
defaultMax := 4096
|
||||
openaiReq.MaxTokens = &defaultMax
|
||||
}
|
||||
|
||||
// 处理 stop_sequences
|
||||
if len(anthropicReq.StopSequences) > 0 {
|
||||
openaiReq.Stop = anthropicReq.StopSequences
|
||||
}
|
||||
|
||||
// 转换 system 消息
|
||||
messages := make([]openai.Message, 0)
|
||||
if anthropicReq.System != "" {
|
||||
messages = append(messages, openai.Message{
|
||||
Role: "system",
|
||||
Content: anthropicReq.System,
|
||||
})
|
||||
}
|
||||
|
||||
// 转换 messages
|
||||
for _, msg := range anthropicReq.Messages {
|
||||
openaiMsg, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, openaiMsg...)
|
||||
}
|
||||
openaiReq.Messages = messages
|
||||
|
||||
// 转换 tools
|
||||
if len(anthropicReq.Tools) > 0 {
|
||||
openaiReq.Tools = make([]openai.Tool, len(anthropicReq.Tools))
|
||||
for i, tool := range anthropicReq.Tools {
|
||||
openaiReq.Tools[i] = openai.Tool{
|
||||
Type: "function",
|
||||
Function: openai.FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: tool.InputSchema,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转换 tool_choice
|
||||
if anthropicReq.ToolChoice != nil {
|
||||
toolChoice, err := convertToolChoice(anthropicReq.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openaiReq.ToolChoice = toolChoice
|
||||
}
|
||||
|
||||
return openaiReq, nil
|
||||
}
|
||||
|
||||
// ConvertResponse 将 OpenAI 响应转换为 Anthropic 响应
|
||||
func ConvertResponse(openaiResp *openai.ChatCompletionResponse) (*MessagesResponse, error) {
|
||||
anthropicResp := &MessagesResponse{
|
||||
ID: openaiResp.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openaiResp.Model,
|
||||
Usage: Usage{
|
||||
InputTokens: openaiResp.Usage.PromptTokens,
|
||||
OutputTokens: openaiResp.Usage.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
// 转换 content
|
||||
if len(openaiResp.Choices) > 0 {
|
||||
choice := openaiResp.Choices[0]
|
||||
content := make([]ContentBlock, 0)
|
||||
|
||||
if choice.Message != nil {
|
||||
// 文本内容
|
||||
if choice.Message.Content != "" {
|
||||
if str, ok := choice.Message.Content.(string); ok && str != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: str,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
// 解析 arguments JSON
|
||||
var input interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil {
|
||||
return nil, fmt.Errorf("解析 tool_call arguments 失败: %w", err)
|
||||
}
|
||||
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anthropicResp.Content = content
|
||||
|
||||
// 转换 finish_reason
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
anthropicResp.StopReason = "end_turn"
|
||||
case "tool_calls":
|
||||
anthropicResp.StopReason = "tool_use"
|
||||
case "length":
|
||||
anthropicResp.StopReason = "max_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
return anthropicResp, nil
|
||||
}
|
||||
|
||||
// convertMessage 转换单条消息
|
||||
func convertMessage(msg AnthropicMessage) ([]openai.Message, error) {
|
||||
var messages []openai.Message
|
||||
|
||||
// 处理 content
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
// 文本内容
|
||||
messages = append(messages, openai.Message{
|
||||
Role: msg.Role,
|
||||
Content: block.Text,
|
||||
})
|
||||
|
||||
case "tool_result":
|
||||
// 工具结果
|
||||
content := ""
|
||||
if str, ok := block.Content.(string); ok {
|
||||
content = str
|
||||
} else {
|
||||
// 如果是数组或其他类型,序列化为 JSON
|
||||
bytes, err := json.Marshal(block.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化 tool_result 内容失败: %w", err)
|
||||
}
|
||||
content = string(bytes)
|
||||
}
|
||||
|
||||
messages = append(messages, openai.Message{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
|
||||
case "image":
|
||||
// MVP 不支持多模态
|
||||
return nil, fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("未知的内容块类型: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 content,创建空消息(不应该发生)
|
||||
if len(messages) == 0 {
|
||||
messages = append(messages, openai.Message{
|
||||
Role: msg.Role,
|
||||
Content: "",
|
||||
})
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertToolChoice 转换工具选择
|
||||
func convertToolChoice(choice interface{}) (interface{}, error) {
|
||||
// 如果是字符串
|
||||
if str, ok := choice.(string); ok {
|
||||
// "auto" 或 "any" 都映射为 "auto"
|
||||
if str == "auto" || str == "any" {
|
||||
return "auto", nil
|
||||
}
|
||||
return nil, fmt.Errorf("无效的 tool_choice 字符串: %s", str)
|
||||
}
|
||||
|
||||
// 如果是对象
|
||||
if obj, ok := choice.(map[string]interface{}); ok {
|
||||
choiceType, ok := obj["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool_choice 对象缺少 type 字段")
|
||||
}
|
||||
|
||||
switch choiceType {
|
||||
case "auto", "any":
|
||||
return "auto", nil
|
||||
case "tool":
|
||||
name, ok := obj["name"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool_choice type=tool 缺少 name 字段")
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]string{
|
||||
"name": name,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("无效的 tool_choice type: %s", choiceType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tool_choice 格式无效")
|
||||
}
|
||||
164
backend/internal/protocol/anthropic/stream_converter.go
Normal file
164
backend/internal/protocol/anthropic/stream_converter.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// StreamConverter 流式转换器
|
||||
type StreamConverter struct {
|
||||
messageID string
|
||||
model string
|
||||
index int // 当前 content block index
|
||||
toolCallArgs map[int]string // 缓存每个 tool_call 的 arguments
|
||||
sentStart bool // 是否已发送 message_start
|
||||
sentBlockStart map[int]bool // 每个 index 是否已发送 content_block_start
|
||||
}
|
||||
|
||||
// NewStreamConverter 创建流式转换器
|
||||
func NewStreamConverter(messageID, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
messageID: messageID,
|
||||
model: model,
|
||||
index: 0,
|
||||
toolCallArgs: make(map[int]string),
|
||||
sentStart: false,
|
||||
sentBlockStart: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertChunk 转换 OpenAI 流块为 Anthropic 事件
|
||||
func (c *StreamConverter) ConvertChunk(chunk *openai.StreamChunk) ([]StreamEvent, error) {
|
||||
var events []StreamEvent
|
||||
|
||||
// 发送 message_start(仅一次)
|
||||
if !c.sentStart {
|
||||
events = append(events, StreamEvent{
|
||||
Type: "message_start",
|
||||
Message: &MessagesResponse{
|
||||
ID: c.messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
c.sentStart = true
|
||||
}
|
||||
|
||||
// 处理每个 choice
|
||||
for _, choice := range chunk.Choices {
|
||||
// 处理 content delta
|
||||
if choice.Delta.Content != "" {
|
||||
// 发送 content_block_start(如果还没发送)
|
||||
if !c.sentBlockStart[c.index] {
|
||||
events = append(events, StreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.index,
|
||||
ContentBlock: &ContentBlock{
|
||||
Type: "text",
|
||||
},
|
||||
})
|
||||
c.sentBlockStart[c.index] = true
|
||||
}
|
||||
|
||||
// 发送 text delta
|
||||
events = append(events, StreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.index,
|
||||
Delta: &Delta{
|
||||
Type: "text_delta",
|
||||
Text: choice.Delta.Content,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 处理 tool_calls delta
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
// 确定 tool_call index
|
||||
toolIndex := c.index + len(c.toolCallArgs)
|
||||
|
||||
// 发送 content_block_start(如果还没发送)
|
||||
if !c.sentBlockStart[toolIndex] {
|
||||
events = append(events, StreamEvent{
|
||||
Type: "content_block_start",
|
||||
Index: toolIndex,
|
||||
ContentBlock: &ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
},
|
||||
})
|
||||
c.sentBlockStart[toolIndex] = true
|
||||
c.toolCallArgs[toolIndex] = ""
|
||||
}
|
||||
|
||||
// 缓存 arguments
|
||||
c.toolCallArgs[toolIndex] += tc.Function.Arguments
|
||||
|
||||
// 发送 input delta
|
||||
events = append(events, StreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: toolIndex,
|
||||
Delta: &Delta{
|
||||
Type: "input_json_delta",
|
||||
Input: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 finish_reason
|
||||
if choice.FinishReason != "" {
|
||||
// 发送 content_block_stop
|
||||
for idx := range c.sentBlockStart {
|
||||
events = append(events, StreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: idx,
|
||||
})
|
||||
}
|
||||
|
||||
// 转换 stop_reason
|
||||
stopReason := ""
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
stopReason = "end_turn"
|
||||
case "tool_calls":
|
||||
stopReason = "tool_use"
|
||||
case "length":
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
// 发送 message_delta
|
||||
events = append(events, StreamEvent{
|
||||
Type: "message_delta",
|
||||
Delta: &Delta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
})
|
||||
|
||||
// 发送 message_stop
|
||||
events = append(events, StreamEvent{
|
||||
Type: "message_stop",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// SerializeEvent 序列化事件为 SSE 格式
|
||||
func SerializeEvent(event StreamEvent) (string, error) {
|
||||
bytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", event.Type, string(bytes)), nil
|
||||
}
|
||||
118
backend/internal/protocol/anthropic/types.go
Normal file
118
backend/internal/protocol/anthropic/types.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package anthropic
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// MessagesRequest Anthropic Messages API 请求结构
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []AnthropicTool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"` // 可以是字符串或对象
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicMessage Anthropic 消息结构
|
||||
type AnthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
}
|
||||
|
||||
// ContentBlock 内容块
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // "text", "image", "tool_use", "tool_result"
|
||||
Text string `json:"text,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"` // 用于 tool_use
|
||||
|
||||
// tool_use 字段
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
// tool_result 字段
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content interface{} `json:"content,omitempty"` // 可以是字符串或数组
|
||||
|
||||
// 多模态字段(MVP 不支持)
|
||||
Source interface{} `json:"source,omitempty"` // 用于 image
|
||||
}
|
||||
|
||||
// AnthropicTool Anthropic 工具定义
|
||||
type AnthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ToolChoice 工具选择
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool"
|
||||
Name string `json:"name,omitempty"` // 当 type="tool" 时使用
|
||||
}
|
||||
|
||||
// MessagesResponse Anthropic Messages API 响应结构
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Content []ContentBlock `json:"content"`
|
||||
Model string `json:"model"`
|
||||
StopReason string `json:"stop_reason,omitempty"` // "end_turn", "max_tokens", "stop_sequence", "tool_use"
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage 使用统计
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// StreamEvent 流式事件
|
||||
type StreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message *MessagesResponse `json:"message,omitempty"` // 用于 message_start
|
||||
Index int `json:"index,omitempty"` // 用于 content_block_* 事件
|
||||
ContentBlock *ContentBlock `json:"content_block,omitempty"` // 用于 content_block_start
|
||||
Delta *Delta `json:"delta,omitempty"` // 用于 content_block_delta
|
||||
}
|
||||
|
||||
// Delta 增量内容
|
||||
type Delta struct {
|
||||
Type string `json:"type,omitempty"` // "text_delta", "input_json_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
Input string `json:"input,omitempty"` // 用于 tool_use 的部分 JSON
|
||||
StopReason string `json:"stop_reason,omitempty"` // 用于 message_delta
|
||||
Usage *Usage `json:"usage,omitempty"` // 用于 message_delta
|
||||
}
|
||||
|
||||
// ErrorResponse Anthropic 错误响应
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Type string `json:"type"` // "invalid_request_error", "authentication_error", etc.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ParseInputJSON 解析 tool_use 的 input(从 JSON 字符串转为 map)
|
||||
func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) {
|
||||
if str, ok := cb.Input.(string); ok {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(str), &result)
|
||||
return result, err
|
||||
}
|
||||
// 如果已经是对象,直接返回
|
||||
if obj, ok := cb.Input.(map[string]interface{}); ok {
|
||||
return obj, nil
|
||||
}
|
||||
return nil, json.Unmarshal([]byte{}, nil) // 返回错误
|
||||
}
|
||||
86
backend/internal/protocol/openai/adapter.go
Normal file
86
backend/internal/protocol/openai/adapter.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Adapter OpenAI 协议适配器(透传)
|
||||
type Adapter struct{}
|
||||
|
||||
// NewAdapter 创建 OpenAI 适配器
|
||||
func NewAdapter() *Adapter {
|
||||
return &Adapter{}
|
||||
}
|
||||
|
||||
// PrepareRequest 准备发送给供应商的请求(透传)
|
||||
func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL string) (*http.Request, error) {
|
||||
// 序列化请求体
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调试日志:打印请求体
|
||||
fmt.Printf("[DEBUG] 请求Body: %s\n", string(body))
|
||||
|
||||
// 创建 HTTP 请求
|
||||
// baseURL 已包含版本路径(如 /v1 或 /v4),只需添加端点路径
|
||||
httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
// ParseResponse 解析供应商响应(透传)
|
||||
func (a *Adapter) ParseResponse(resp *http.Response) (*ChatCompletionResponse, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result ChatCompletionResponse
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ParseErrorResponse 解析错误响应
|
||||
func (a *Adapter) ParseErrorResponse(resp *http.Response) (*ErrorResponse, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result ErrorResponse
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ParseStreamChunk 解析流式响应块
|
||||
func (a *Adapter) ParseStreamChunk(data []byte) (*StreamChunk, error) {
|
||||
var chunk StreamChunk
|
||||
err := json.Unmarshal(data, &chunk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &chunk, nil
|
||||
}
|
||||
131
backend/internal/protocol/openai/types.go
Normal file
131
backend/internal/protocol/openai/types.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package openai
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completions API 请求结构
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop interface{} `json:"stop,omitempty"` // 可以是字符串或字符串数组
|
||||
N *int `json:"n,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"` // 可以是字符串或对象
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// Message OpenAI 消息结构
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"` // 可以是字符串或数组(多模态,MVP不支持)
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // 用于 role="tool" 的消息
|
||||
}
|
||||
|
||||
// Tool OpenAI 工具定义
|
||||
type Tool struct {
|
||||
Type string `json:"type"` // 目前只有 "function"
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionDefinition 函数定义
|
||||
type FunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall 工具调用
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "function"
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall 函数调用
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"` // JSON 字符串
|
||||
}
|
||||
|
||||
// ChatCompletionResponse OpenAI Chat Completions API 响应结构
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Choice 响应选项
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Delta `json:"delta,omitempty"` // 用于流式响应
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// Delta 流式响应增量
|
||||
type Delta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// Usage Token 使用统计
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// StreamChunk 流式响应块
|
||||
type StreamChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []StreamChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// StreamChoice 流式响应选项
|
||||
type StreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse OpenAI 错误响应
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
// ParseToolCallArguments 解析 tool_call 的 arguments(从 JSON 字符串转为 map)
|
||||
func (tc *ToolCall) ParseToolCallArguments() (map[string]interface{}, error) {
|
||||
var args map[string]interface{}
|
||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
return args, err
|
||||
}
|
||||
|
||||
// SerializeToolCallArguments 序列化 tool_call 的 arguments(从 map 转为 JSON 字符串)
|
||||
func SerializeToolCallArguments(args map[string]interface{}) (string, error) {
|
||||
bytes, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
177
backend/internal/provider/client.go
Normal file
177
backend/internal/provider/client.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// Client OpenAI 兼容供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
adapter *openai.Adapter
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second, // 非流式请求超时
|
||||
},
|
||||
adapter: openai.NewAdapter(),
|
||||
}
|
||||
}
|
||||
|
||||
// SendRequest 发送非流式请求
|
||||
func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
// 准备请求
|
||||
httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 调试日志:打印完整请求信息
|
||||
fmt.Printf("[DEBUG] 请求URL: %s\n", httpReq.URL.String())
|
||||
fmt.Printf("[DEBUG] 请求Method: %s\n", httpReq.Method)
|
||||
fmt.Printf("[DEBUG] 请求Headers: %v\n", httpReq.Header)
|
||||
|
||||
// 设置上下文
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 解析错误响应
|
||||
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
result, err := c.adapter.ParseResponse(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SendStreamRequest 发送流式请求
|
||||
func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) {
|
||||
// 确保请求设置为流式
|
||||
req.Stream = true
|
||||
|
||||
// 准备请求
|
||||
httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置上下文
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
|
||||
}
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan StreamEvent, 100)
|
||||
|
||||
// 启动 goroutine 读取流
|
||||
go c.readStream(ctx, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
// StreamEvent 流事件
|
||||
type StreamEvent struct {
|
||||
Data []byte
|
||||
Error error
|
||||
Done bool
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流
|
||||
func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) {
|
||||
defer close(eventChan)
|
||||
defer body.Close()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
var dataBuf []byte
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
eventChan <- StreamEvent{Error: ctx.Err()}
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := body.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// 流结束
|
||||
return
|
||||
}
|
||||
eventChan <- StreamEvent{Error: err}
|
||||
return
|
||||
}
|
||||
|
||||
dataBuf = append(dataBuf, buf[:n]...)
|
||||
|
||||
// 处理完整的 SSE 事件
|
||||
for {
|
||||
// 查找事件边界(双换行)
|
||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// 提取事件
|
||||
event := dataBuf[:idx]
|
||||
dataBuf = dataBuf[idx+2:]
|
||||
|
||||
// 解析 data 行
|
||||
lines := strings.Split(string(event), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送数据
|
||||
eventChan <- StreamEvent{Data: []byte(data)}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
71
backend/internal/router/model_router.go
Normal file
71
backend/internal/router/model_router.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModelNotFound = errors.New("模型未找到")
|
||||
ErrModelDisabled = errors.New("模型已禁用")
|
||||
ErrProviderDisabled = errors.New("供应商已禁用")
|
||||
)
|
||||
|
||||
// RouteResult 路由结果
|
||||
type RouteResult struct {
|
||||
Provider *config.Provider
|
||||
Model *config.Model
|
||||
}
|
||||
|
||||
// Router 模型路由器
|
||||
type Router struct{}
|
||||
|
||||
// NewRouter 创建路由器
|
||||
func NewRouter() *Router {
|
||||
return &Router{}
|
||||
}
|
||||
|
||||
// Route 根据模型名称路由到供应商
|
||||
func (r *Router) Route(modelName string) (*RouteResult, error) {
|
||||
// 查询模型
|
||||
models, err := config.ListModels("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
|
||||
// 查找匹配的模型
|
||||
var targetModel *config.Model
|
||||
for i := range models {
|
||||
if models[i].ModelName == modelName {
|
||||
targetModel = &models[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetModel == nil {
|
||||
return nil, ErrModelNotFound
|
||||
}
|
||||
|
||||
// 检查模型是否启用
|
||||
if !targetModel.Enabled {
|
||||
return nil, ErrModelDisabled
|
||||
}
|
||||
|
||||
// 查询供应商
|
||||
provider, err := config.GetProvider(targetModel.ProviderID, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询供应商失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查供应商是否启用
|
||||
if !provider.Enabled {
|
||||
return nil, ErrProviderDisabled
|
||||
}
|
||||
|
||||
return &RouteResult{
|
||||
Provider: provider,
|
||||
Model: targetModel,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user