refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化 ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
This commit is contained in:
@@ -4,10 +4,14 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/v1/messages`)
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`)
|
||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||
- 同协议透传(零语义损失、零序列化开销)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 支持 Thinking / Reasoning
|
||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
@@ -48,19 +52,36 @@ backend/
|
||||
│ │ │ ├── logging.go
|
||||
│ │ │ ├── recovery.go
|
||||
│ │ │ └── cors.go
|
||||
│ │ ├── openai_handler.go
|
||||
│ │ ├── anthropic_handler.go
|
||||
│ │ ├── proxy_handler.go # 统一代理处理器
|
||||
│ │ ├── provider_handler.go
|
||||
│ │ ├── model_handler.go
|
||||
│ │ └── stats_handler.go
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ │ ├── openai/
|
||||
│ │ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ │ └── adapter.go # OpenAI 协议适配
|
||||
│ │ └── anthropic/
|
||||
│ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ ├── converter.go # 协议转换
|
||||
│ │ └── stream_converter.go # 流式转换
|
||||
│ ├── conversion/ # 协议转换引擎
|
||||
│ │ ├── canonical/ # Canonical Model
|
||||
│ │ │ ├── types.go # 核心请求/响应类型
|
||||
│ │ │ ├── stream.go # 流式事件类型
|
||||
│ │ │ └── extended.go # 扩展层 Models
|
||||
│ │ ├── openai/ # OpenAI 协议适配器
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ ├── adapter.go
|
||||
│ │ │ ├── decoder.go
|
||||
│ │ │ ├── encoder.go
|
||||
│ │ │ ├── stream_decoder.go
|
||||
│ │ │ └── stream_encoder.go
|
||||
│ │ ├── anthropic/ # Anthropic 协议适配器
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ ├── adapter.go
|
||||
│ │ │ ├── decoder.go
|
||||
│ │ │ ├── encoder.go
|
||||
│ │ │ ├── stream_decoder.go
|
||||
│ │ │ └── stream_encoder.go
|
||||
│ │ ├── adapter.go # ProtocolAdapter 接口 + Registry
|
||||
│ │ ├── stream.go # StreamDecoder/Encoder/Converter
|
||||
│ │ ├── middleware.go # Middleware 接口和 Chain
|
||||
│ │ ├── engine.go # ConversionEngine 门面
|
||||
│ │ ├── errors.go # ConversionError
|
||||
│ │ ├── interface.go # InterfaceType 枚举
|
||||
│ │ └── provider.go # TargetProvider
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ │ └── client.go
|
||||
│ ├── repository/ # 数据访问层
|
||||
@@ -184,10 +205,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
### 代理接口
|
||||
|
||||
#### OpenAI Chat Completions
|
||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
||||
|
||||
#### OpenAI 协议代理
|
||||
|
||||
```
|
||||
POST /v1/chat/completions
|
||||
POST /openai/v1/chat/completions
|
||||
GET /openai/v1/models
|
||||
POST /openai/v1/embeddings
|
||||
POST /openai/v1/rerank
|
||||
```
|
||||
|
||||
请求示例:
|
||||
@@ -202,10 +228,11 @@ POST /v1/chat/completions
|
||||
}
|
||||
```
|
||||
|
||||
#### Anthropic Messages
|
||||
#### Anthropic 协议代理
|
||||
|
||||
```
|
||||
POST /v1/messages
|
||||
POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
请求示例:
|
||||
@@ -220,6 +247,8 @@ POST /v1/messages
|
||||
}
|
||||
```
|
||||
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||
|
||||
### 管理接口
|
||||
|
||||
#### 供应商管理
|
||||
@@ -237,10 +266,15 @@ POST /v1/messages
|
||||
"id": "openai",
|
||||
"name": "OpenAI",
|
||||
"api_key": "sk-...",
|
||||
"base_url": "https://api.openai.com/v1"
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"protocol": "openai"
|
||||
}
|
||||
```
|
||||
|
||||
**Protocol 字段说明:**
|
||||
- `protocol` 标识上游供应商使用的协议类型,可选值:`"openai"`(默认)、`"anthropic"`
|
||||
- 同协议透传时,请求体和响应体原样转发,零序列化开销
|
||||
|
||||
**重要说明:**
|
||||
- `base_url` 应配置到 API 版本路径,不包含具体端点
|
||||
- OpenAI: `https://api.openai.com/v1`
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
@@ -70,30 +73,37 @@ func main() {
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
|
||||
// 6. 初始化 provider client
|
||||
// 6. 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry)
|
||||
|
||||
// 7. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
|
||||
// 7. 初始化 handler 层
|
||||
openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService)
|
||||
anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService)
|
||||
// 8. 初始化 handler 层
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 8. 创建 Gin 引擎
|
||||
// 9. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
// 注册中间件(按正确顺序)
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 9. 启动服务器
|
||||
// 10. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Handler: r,
|
||||
@@ -108,7 +118,6 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
@@ -137,12 +146,10 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 运行数据库迁移
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -151,14 +158,12 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
// 记录连接池状态
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations 使用 goose 执行数据库迁移
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
@@ -178,18 +183,14 @@ func runMigrations(db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMigrationsDir 获取迁移文件目录路径
|
||||
func getMigrationsDir() string {
|
||||
// 从可执行文件位置推断迁移目录
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
// cmd/server/main.go → backend/ → backend/migrations/
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
// 回退到相对路径
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
@@ -205,12 +206,9 @@ func formatAddr(port int) string {
|
||||
return fmt.Sprintf(":%d", port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// OpenAI 协议代理
|
||||
r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions)
|
||||
|
||||
// Anthropic 协议代理
|
||||
r.POST("/v1/messages", anthropicHandler.HandleMessages)
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// 统一代理入口: /{protocol}/v1/{path}
|
||||
r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy)
|
||||
|
||||
// 供应商管理 API
|
||||
providers := r.Group("/api/providers")
|
||||
|
||||
@@ -10,6 +10,7 @@ type Provider struct {
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
100
backend/internal/conversion/adapter.go
Normal file
100
backend/internal/conversion/adapter.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package conversion
|
||||
|
||||
import (
	"fmt"
	"sort"
	"sync"

	"nex/backend/internal/conversion/canonical"
)
||||
|
||||
// ProtocolAdapter 协议适配器接口
|
||||
type ProtocolAdapter interface {
|
||||
ProtocolName() string
|
||||
ProtocolVersion() string
|
||||
SupportsPassthrough() bool
|
||||
|
||||
DetectInterfaceType(nativePath string) InterfaceType
|
||||
BuildUrl(nativePath string, interfaceType InterfaceType) string
|
||||
BuildHeaders(provider *TargetProvider) map[string]string
|
||||
SupportsInterface(interfaceType InterfaceType) bool
|
||||
|
||||
DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error)
|
||||
EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error)
|
||||
EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error)
|
||||
|
||||
CreateStreamDecoder() StreamDecoder
|
||||
CreateStreamEncoder() StreamEncoder
|
||||
|
||||
EncodeError(err *ConversionError) ([]byte, int)
|
||||
|
||||
DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error)
|
||||
EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error)
|
||||
DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error)
|
||||
EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error)
|
||||
DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error)
|
||||
EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error)
|
||||
DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error)
|
||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||
}
|
||||
|
||||
// AdapterRegistry 适配器注册表接口
|
||||
type AdapterRegistry interface {
|
||||
Register(adapter ProtocolAdapter) error
|
||||
Get(protocolName string) (ProtocolAdapter, error)
|
||||
ListProtocols() []string
|
||||
}
|
||||
|
||||
// memoryRegistry 基于内存的适配器注册表
|
||||
type memoryRegistry struct {
|
||||
mu sync.RWMutex
|
||||
adapters map[string]ProtocolAdapter
|
||||
}
|
||||
|
||||
// NewMemoryRegistry 创建内存注册表
|
||||
func NewMemoryRegistry() AdapterRegistry {
|
||||
return &memoryRegistry{
|
||||
adapters: make(map[string]ProtocolAdapter),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册适配器
|
||||
func (r *memoryRegistry) Register(adapter ProtocolAdapter) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
name := adapter.ProtocolName()
|
||||
if _, exists := r.adapters[name]; exists {
|
||||
return fmt.Errorf("适配器已注册: %s", name)
|
||||
}
|
||||
r.adapters[name] = adapter
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取适配器
|
||||
func (r *memoryRegistry) Get(protocolName string) (ProtocolAdapter, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
adapter, ok := r.adapters[protocolName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("未找到适配器: %s", protocolName)
|
||||
}
|
||||
return adapter, nil
|
||||
}
|
||||
|
||||
// ListProtocols 列出所有已注册协议
|
||||
func (r *memoryRegistry) ListProtocols() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
protocols := make([]string, 0, len(r.adapters))
|
||||
for name := range r.adapters {
|
||||
protocols = append(protocols, name)
|
||||
}
|
||||
return protocols
|
||||
}
|
||||
199
backend/internal/conversion/anthropic/adapter.go
Normal file
199
backend/internal/conversion/anthropic/adapter.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter implements the Anthropic Messages protocol for the conversion engine.
type Adapter struct{}

// NewAdapter constructs an Anthropic protocol adapter.
func NewAdapter() *Adapter {
	return new(Adapter)
}

// modelInfoRegex matches "/v1/models/{id}" with a single non-empty id segment.
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)
|
||||
|
||||
// ProtocolName 返回协议名称
|
||||
func (a *Adapter) ProtocolName() string { return "anthropic" }
|
||||
|
||||
// ProtocolVersion 返回协议版本
|
||||
func (a *Adapter) ProtocolVersion() string { return "2023-06-01" }
|
||||
|
||||
// SupportsPassthrough 支持同协议透传
|
||||
func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/messages":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case modelInfoRegex.MatchString(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/messages"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"x-api-key": provider.APIKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if v, ok := provider.AdapterConfig["anthropic_version"].(string); ok && v != "" {
|
||||
headers["anthropic-version"] = v
|
||||
}
|
||||
if betas, ok := provider.AdapterConfig["anthropic_beta"].([]string); ok && len(betas) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(betas, ",")
|
||||
} else if betas, ok := provider.AdapterConfig["anthropic_beta"].([]any); ok && len(betas) > 0 {
|
||||
var parts []string
|
||||
for _, b := range betas {
|
||||
if s, ok := b.(string); ok {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(parts, ",")
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest 解码请求
|
||||
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
return decodeRequest(raw)
|
||||
}
|
||||
|
||||
// EncodeRequest 编码请求
|
||||
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return encodeRequest(req, provider)
|
||||
}
|
||||
|
||||
// DecodeResponse 解码响应
|
||||
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return decodeResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeResponse 编码响应
|
||||
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
return encodeResponse(resp)
|
||||
}
|
||||
|
||||
// CreateStreamDecoder 创建流式解码器
|
||||
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
|
||||
return NewStreamDecoder()
|
||||
}
|
||||
|
||||
// CreateStreamEncoder 创建流式编码器
|
||||
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
|
||||
return NewStreamEncoder()
|
||||
}
|
||||
|
||||
// EncodeError 编码错误
|
||||
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
errType := string(err.Code)
|
||||
statusCode := 500
|
||||
|
||||
errMsg := ErrorResponse{
|
||||
Type: "error",
|
||||
Error: ErrorDetail{
|
||||
Type: errType,
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
// DecodeModelsResponse 解码模型列表响应
|
||||
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
|
||||
return decodeModelsResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeModelsResponse 编码模型列表响应
|
||||
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
return encodeModelsResponse(list)
|
||||
}
|
||||
|
||||
// DecodeModelInfoResponse 解码模型详情响应
|
||||
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
return decodeModelInfoResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeModelInfoResponse 编码模型详情响应
|
||||
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
return encodeModelInfoResponse(info)
|
||||
}
|
||||
|
||||
// DecodeEmbeddingRequest Anthropic 不支持嵌入
|
||||
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// EncodeEmbeddingRequest Anthropic 不支持嵌入
|
||||
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// DecodeEmbeddingResponse Anthropic 不支持嵌入
|
||||
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// EncodeEmbeddingResponse Anthropic 不支持嵌入
|
||||
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// DecodeRerankRequest Anthropic 不支持重排序
|
||||
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// EncodeRerankRequest Anthropic 不支持重排序
|
||||
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// DecodeRerankResponse Anthropic 不支持重排序
|
||||
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// EncodeRerankResponse Anthropic 不支持重排序
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
210
backend/internal/conversion/anthropic/adapter_test.go
Normal file
210
backend/internal/conversion/anthropic/adapter_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdapter_ProtocolName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.Equal(t, "anthropic", a.ProtocolName())
|
||||
}
|
||||
|
||||
func TestAdapter_ProtocolVersion(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.Equal(t, "2023-06-01", a.ProtocolVersion())
|
||||
}
|
||||
|
||||
func TestAdapter_SupportsPassthrough(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.True(t, a.SupportsPassthrough())
|
||||
}
|
||||
|
||||
func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"聊天消息", "/v1/messages", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/v1/models/claude-3", conversion.InterfaceTypeModelInfo},
|
||||
{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.DetectInterfaceType(tt.path)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.BuildUrl(tt.nativePath, tt.interfaceType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_Basic(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "sk-ant-test", headers["x-api-key"])
|
||||
assert.Equal(t, "2023-06-01", headers["anthropic-version"])
|
||||
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_CustomVersion(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
provider.AdapterConfig["anthropic_version"] = "2024-01-01"
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "2024-01-01", headers["anthropic-version"])
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_AnthropicBeta(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
provider.AdapterConfig["anthropic_beta"] = []string{"prompt-caching-2024-07-31", "max-tokens-3-5-sonnet-2024-07-15"}
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15", headers["anthropic-beta"])
|
||||
}
|
||||
|
||||
func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
{"模型详情", conversion.InterfaceTypeModelInfo, true},
|
||||
{"嵌入", conversion.InterfaceTypeEmbeddings, false},
|
||||
{"重排序", conversion.InterfaceTypeRerank, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.SupportsInterface(tt.interfaceType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_EncodeError(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
||||
|
||||
body, statusCode := a.EncodeError(convErr)
|
||||
require.Equal(t, 500, statusCode)
|
||||
|
||||
var resp ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(body, &resp))
|
||||
assert.Equal(t, "error", resp.Type)
|
||||
assert.Equal(t, "INVALID_INPUT", resp.Error.Type)
|
||||
assert.Equal(t, "参数无效", resp.Error.Message)
|
||||
}
|
||||
|
||||
func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入请求", func(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序请求", func(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
427
backend/internal/conversion/anthropic/decoder.go
Normal file
427
backend/internal/conversion/anthropic/decoder.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest 将 Anthropic 请求解码为 Canonical 请求
|
||||
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
var req MessagesRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 请求失败").WithCause(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
|
||||
}
|
||||
if len(req.Messages) == 0 {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
|
||||
}
|
||||
|
||||
system := decodeSystem(req.System)
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
tools := decodeTools(req.Tools)
|
||||
toolChoice := decodeToolChoice(req.ToolChoice)
|
||||
params := decodeParameters(&req)
|
||||
thinking := decodeThinking(req.Thinking, req.OutputConfig)
|
||||
outputFormat := decodeOutputFormat(req.OutputConfig)
|
||||
|
||||
var parallelToolUse *bool
|
||||
if req.DisableParallelToolUse != nil && *req.DisableParallelToolUse {
|
||||
val := false
|
||||
parallelToolUse = &val
|
||||
}
|
||||
|
||||
var userID string
|
||||
if req.Metadata != nil {
|
||||
userID = req.Metadata.UserID
|
||||
}
|
||||
|
||||
return &canonical.CanonicalRequest{
|
||||
Model: req.Model,
|
||||
System: system,
|
||||
Messages: canonicalMsgs,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
Parameters: params,
|
||||
Thinking: thinking,
|
||||
Stream: req.Stream,
|
||||
UserID: userID,
|
||||
OutputFormat: outputFormat,
|
||||
ParallelToolUse: parallelToolUse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeSystem 解码系统消息
|
||||
func decodeSystem(system any) any {
|
||||
if system == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
case []any:
|
||||
var blocks []canonical.SystemBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok {
|
||||
blocks = append(blocks, canonical.SystemBlock{Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeMessage converts one Anthropic message into canonical messages.
//
// User messages are split: tool_result blocks become a separate
// RoleTool message (the canonical model keeps tool results out of the
// user turn), while the remaining blocks stay a RoleUser message.
// Assistant messages map 1:1. Any other role yields nil (dropped).
// Empty content is replaced with a single empty text block so a
// message is never content-less.
func decodeMessage(msg Message) []canonical.CanonicalMessage {
	switch msg.Role {
	case "user":
		blocks := decodeContentBlocks(msg.Content)
		var toolResults []canonical.ContentBlock
		var others []canonical.ContentBlock
		for _, b := range blocks {
			if b.Type == "tool_result" {
				toolResults = append(toolResults, b)
			} else {
				others = append(others, b)
			}
		}
		var result []canonical.CanonicalMessage
		if len(others) > 0 {
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: others})
		}
		if len(toolResults) > 0 {
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleTool, Content: toolResults})
		}
		if len(result) == 0 {
			// All blocks were dropped/empty: emit a placeholder user
			// message with one empty text block.
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
		}
		return result

	case "assistant":
		blocks := decodeContentBlocks(msg.Content)
		if len(blocks) == 0 {
			blocks = append(blocks, canonical.NewTextBlock(""))
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
	}
	return nil
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
}
|
||||
var content json.RawMessage
|
||||
if c, ok := m["content"]; ok {
|
||||
switch cv := c.(type) {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
}
|
||||
return &canonical.ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]canonical.CanonicalTool, len(tools))
|
||||
for i, t := range tools {
|
||||
result[i] = canonical.CanonicalTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.InputSchema,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice 解码工具选择
|
||||
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
if toolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := toolChoice.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeParameters 解码请求参数
|
||||
func decodeParameters(req *MessagesRequest) canonical.RequestParameters {
|
||||
params := canonical.RequestParameters{
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
val := req.MaxTokens
|
||||
params.MaxTokens = &val
|
||||
}
|
||||
if len(req.StopSequences) > 0 {
|
||||
params.StopSequences = req.StopSequences
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// decodeThinking maps the Anthropic thinking configuration onto the
// canonical one; returns nil when thinking is not configured. The
// effort level lives under output_config on the Anthropic side, so it
// is merged in from there when present.
func decodeThinking(thinking *ThinkingConfig, outputConfig *OutputConfig) *canonical.ThinkingConfig {
	if thinking == nil {
		return nil
	}
	cfg := &canonical.ThinkingConfig{
		Type:         thinking.Type,
		BudgetTokens: thinking.BudgetTokens,
	}
	if outputConfig != nil && outputConfig.Effort != "" {
		cfg.Effort = outputConfig.Effort
	}
	return cfg
}
|
||||
|
||||
// decodeOutputFormat maps output_config.format onto a canonical output
// format. Only the json_schema form with a non-nil schema is accepted;
// everything else maps to nil (no format constraint). The canonical
// name defaults to "output" and strict validation is enabled, since
// the Anthropic format carries neither field.
func decodeOutputFormat(outputConfig *OutputConfig) *canonical.OutputFormat {
	if outputConfig == nil || outputConfig.Format == nil {
		return nil
	}
	if outputConfig.Format.Type == "json_schema" && outputConfig.Format.Schema != nil {
		return &canonical.OutputFormat{
			Type:   "json_schema",
			Name:   "output",
			Schema: outputConfig.Format.Schema,
			Strict: boolPtr(true),
		}
	}
	return nil
}
|
||||
|
||||
// decodeResponse decodes a non-streaming Anthropic Messages response
// into the canonical response model.
//
// redacted_thinking blocks are dropped; if that leaves no content, a
// single empty text block is substituted so the response is never
// content-less. Cache token counts are carried over only when present.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp MessagesResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 响应失败").WithCause(err)
	}

	var blocks []canonical.ContentBlock
	for _, block := range resp.Content {
		switch block.Type {
		case "text":
			blocks = append(blocks, canonical.NewTextBlock(block.Text))
		case "tool_use":
			blocks = append(blocks, canonical.NewToolUseBlock(block.ID, block.Name, block.Input))
		case "thinking":
			blocks = append(blocks, canonical.NewThinkingBlock(block.Thinking))
		case "redacted_thinking":
			// Opaque provider data; discarded.
		}
	}
	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	sr := mapStopReason(resp.StopReason)
	usage := canonical.CanonicalUsage{
		InputTokens:  resp.Usage.InputTokens,
		OutputTokens: resp.Usage.OutputTokens,
	}
	if resp.Usage.CacheReadInputTokens != nil {
		usage.CacheReadTokens = resp.Usage.CacheReadInputTokens
	}
	if resp.Usage.CacheCreationInputTokens != nil {
		usage.CacheCreationTokens = resp.Usage.CacheCreationInputTokens
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: &sr,
		Usage:      usage,
	}, nil
}
|
||||
|
||||
// mapStopReason 映射停止原因
|
||||
func mapStopReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "end_turn":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "max_tokens":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_use":
|
||||
return canonical.StopReasonToolUse
|
||||
case "stop_sequence":
|
||||
return canonical.StopReasonStopSequence
|
||||
case "pause_turn":
|
||||
return canonical.StopReason("pause_turn")
|
||||
case "refusal":
|
||||
return canonical.StopReasonRefusal
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeModelsResponse 解码模型列表响应
|
||||
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
|
||||
var resp ModelsResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
models := make([]canonical.CanonicalModel, len(resp.Data))
|
||||
for i, m := range resp.Data {
|
||||
name := m.DisplayName
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
models[i] = canonical.CanonicalModel{
|
||||
ID: m.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(m.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}
|
||||
}
|
||||
return &canonical.CanonicalModelList{Models: models}, nil
|
||||
}
|
||||
|
||||
// decodeModelInfoResponse 解码模型详情响应
|
||||
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
var resp ModelInfoResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := resp.DisplayName
|
||||
if name == "" {
|
||||
name = resp.ID
|
||||
}
|
||||
return &canonical.CanonicalModelInfo{
|
||||
ID: resp.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(resp.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTimestamp converts an RFC 3339 timestamp string to Unix
// seconds. Empty or malformed input yields 0.
func parseTimestamp(s string) int64 {
	if s == "" {
		return 0
	}
	if parsed, err := time.Parse(time.RFC3339, s); err == nil {
		return parsed.Unix()
	}
	return 0
}
|
||||
|
||||
// formatTimestamp renders a Unix timestamp as an RFC 3339 string in
// UTC. A zero timestamp maps to the fixed placeholder
// 2025-01-01T00:00:00Z.
func formatTimestamp(unix int64) string {
	if unix == 0 {
		// Placeholder for entries without a known creation time.
		return "2025-01-01T00:00:00Z"
	}
	return time.Unix(unix, 0).UTC().Format(time.RFC3339)
}
|
||||
|
||||
// boolPtr returns a pointer to a copy of b.
func boolPtr(b bool) *bool {
	v := b
	return &v
}
|
||||
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeRequest_Basic checks model, message count/role, and the
// max_tokens → Parameters.MaxTokens mapping.
func TestDecodeRequest_Basic(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "claude-3", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 1024, *req.Parameters.MaxTokens)
}

// TestDecodeRequest_System checks the plain-string system prompt form.
func TestDecodeRequest_System(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"system": "你是助手",
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手", req.System)
}

// TestDecodeRequest_SystemBlocks checks the block-array system form
// decodes to []canonical.SystemBlock in order.
func TestDecodeRequest_SystemBlocks(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"system": [{"text": "指令1"}, {"text": "指令2"}],
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	blocks, ok := req.System.([]canonical.SystemBlock)
	require.True(t, ok)
	assert.Len(t, blocks, 2)
	assert.Equal(t, "指令1", blocks[0].Text)
}

// TestDecodeRequest_ToolResultSplit checks that tool_result blocks in
// a user message are split off into a standalone tool-role message.
func TestDecodeRequest_ToolResultSplit(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [
			{
				"role": "user",
				"content": [
					{"type": "text", "text": "查询天气"},
					{"type": "tool_result", "tool_use_id": "tool_1", "content": "晴天"}
				]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	// tool_result blocks inside a user message must become a separate
	// tool-role message.
	assert.Len(t, req.Messages, 2)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
}

// TestDecodeRequest_MissingModel expects an INVALID_INPUT error when
// the model field is absent.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_MissingMessages expects an INVALID_INPUT error
// when the messages field is absent.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model": "claude-3", "max_tokens": 1024}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
func TestDecodeRequest_MissingMessages(t *testing.T) {
|
||||
body := []byte(`{"model": "claude-3", "max_tokens": 1024}`)
|
||||
_, err := decodeRequest(body)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "INVALID_INPUT")
|
||||
}
|
||||
|
||||
// TestDecodeResponse_Basic checks ID/model/content/stop-reason/usage
// decoding of a plain text response.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "msg_123",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [{"type": "text", "text": "你好"}],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "msg_123", resp.ID)
	assert.Equal(t, "claude-3", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
}

// TestDecodeResponse_Thinking checks that thinking blocks are kept and
// ordered before the answer text.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "msg_456",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "thinking", "thinking": "思考过程"},
			{"type": "text", "text": "回答"}
		],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 20}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "thinking", resp.Content[0].Type)
	assert.Equal(t, "思考过程", resp.Content[0].Thinking)
	assert.Equal(t, "text", resp.Content[1].Type)
	assert.Equal(t, "回答", resp.Content[1].Text)
}

// TestDecodeModelsResponse checks display-name fallback to ID and the
// RFC 3339 → Unix conversion of created_at.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"data": [
			{"id": "claude-3-opus", "type": "model", "display_name": "Claude 3 Opus", "created_at": "2024-01-15T00:00:00Z"},
			{"id": "claude-3-sonnet", "type": "model", "created_at": "2024-02-01T00:00:00Z"}
		],
		"has_more": false
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "claude-3-opus", list.Models[0].ID)
	assert.Equal(t, "Claude 3 Opus", list.Models[0].Name)
	// created_at RFC3339 → Unix
	assert.NotEqual(t, int64(0), list.Models[0].Created)
	// Without display_name the ID is used as the name.
	assert.Equal(t, "claude-3-sonnet", list.Models[1].Name)
}

// TestDecodeRequest_InvalidJSON expects a JSON_PARSE_ERROR for
// non-JSON input.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}
|
||||
|
||||
// TestDecodeRequest_Thinking checks the enabled thinking config with a
// token budget.
func TestDecodeRequest_Thinking(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"thinking": {"type": "enabled", "budget_tokens": 5000}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "enabled", req.Thinking.Type)
	require.NotNil(t, req.Thinking.BudgetTokens)
	assert.Equal(t, 5000, *req.Thinking.BudgetTokens)
}

// TestDecodeRequest_ThinkingAdaptive checks the adaptive thinking form
// (no budget).
func TestDecodeRequest_ThinkingAdaptive(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"thinking": {"type": "adaptive"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "adaptive", req.Thinking.Type)
}

// TestDecodeRequest_OutputConfig checks output_config.format with a
// json_schema decodes to a canonical OutputFormat.
func TestDecodeRequest_OutputConfig(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"output_config": {
			"format": {
				"type": "json_schema",
				"schema": {"type": "object", "properties": {"name": {"type": "string"}}}
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.NotNil(t, req.OutputFormat.Schema)
}

// TestDecodeRequest_DisableParallelToolUse checks the mapping of
// disable_parallel_tool_use=true onto ParallelToolUse=false.
func TestDecodeRequest_DisableParallelToolUse(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"disable_parallel_tool_use": true
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.ParallelToolUse)
	assert.False(t, *req.ParallelToolUse)
}

// TestDecodeResponse_ToolUse checks decoding of a tool_use content
// block (id, name, raw input).
func TestDecodeResponse_ToolUse(t *testing.T) {
	body := []byte(`{
		"id": "msg_tool",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "tool_use", "id": "tool_1", "name": "search", "input": {"q": "test"}}
		],
		"stop_reason": "tool_use",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	require.Len(t, resp.Content, 1)
	assert.Equal(t, "tool_use", resp.Content[0].Type)
	assert.Equal(t, "tool_1", resp.Content[0].ID)
	assert.Equal(t, "search", resp.Content[0].Name)
	assert.NotNil(t, resp.Content[0].Input)
}

// TestDecodeResponse_RedactedThinking checks that redacted_thinking
// blocks are silently dropped.
func TestDecodeResponse_RedactedThinking(t *testing.T) {
	body := []byte(`{
		"id": "msg_redacted",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "redacted_thinking", "data": "..."},
			{"type": "text", "text": "回答"}
		],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "text", resp.Content[0].Type)
	assert.Equal(t, "回答", resp.Content[0].Text)
}

// TestDecodeResponse_StopReasons table-tests the Anthropic →
// canonical stop-reason mapping.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name   string
		reason string
		want   canonical.StopReason
	}{
		{"end_turn→end_turn", "end_turn", canonical.StopReasonEndTurn},
		{"max_tokens→max_tokens", "max_tokens", canonical.StopReasonMaxTokens},
		{"tool_use→tool_use", "tool_use", canonical.StopReasonToolUse},
		{"stop_sequence→stop_sequence", "stop_sequence", canonical.StopReasonStopSequence},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "msg-1",
				"type": "message",
				"role": "assistant",
				"model": "claude-3",
				"content": [{"type": "text", "text": "ok"}],
				"stop_reason": "%s",
				"usage": {"input_tokens": 1, "output_tokens": 1}
			}`, tt.reason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}

// TestDecodeResponse_Usage checks token counts, including the optional
// cache_read_input_tokens field.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "msg_usage",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [{"type": "text", "text": "ok"}],
		"stop_reason": "end_turn",
		"usage": {
			"input_tokens": 100,
			"output_tokens": 50,
			"cache_read_input_tokens": 30
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 30, *resp.Usage.CacheReadTokens)
}
|
||||
449
backend/internal/conversion/anthropic/encoder.go
Normal file
449
backend/internal/conversion/anthropic/encoder.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// encodeRequest encodes a canonical request as an Anthropic Messages
// API request body, targeting the given provider's model name.
// max_tokens is mandatory in the Anthropic protocol and defaults to
// 4096 when the canonical request does not carry one.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// max_tokens is required by the API.
	if req.Parameters.MaxTokens != nil {
		result["max_tokens"] = *req.Parameters.MaxTokens
	} else {
		result["max_tokens"] = 4096
	}

	// System prompt (string or block-list form).
	if req.System != nil {
		result["system"] = encodeSystem(req.System)
	}

	// Conversation messages.
	result["messages"] = encodeMessages(req.Messages)

	// Sampling parameters — emitted only when set.
	if req.Parameters.Temperature != nil {
		result["temperature"] = *req.Parameters.Temperature
	}
	if req.Parameters.TopP != nil {
		result["top_p"] = *req.Parameters.TopP
	}
	if req.Parameters.TopK != nil {
		result["top_k"] = *req.Parameters.TopK
	}
	if len(req.Parameters.StopSequences) > 0 {
		result["stop_sequences"] = req.Parameters.StopSequences
	}

	// Tool definitions.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tool := map[string]any{
				"name":         t.Name,
				"input_schema": t.InputSchema,
			}
			if t.Description != "" {
				tool["description"] = t.Description
			}
			tools[i] = tool
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common optional fields.
	if req.UserID != "" {
		result["metadata"] = map[string]any{"user_id": req.UserID}
	}
	if req.ParallelToolUse != nil && !*req.ParallelToolUse {
		result["disable_parallel_tool_use"] = true
	}
	if req.Thinking != nil {
		result["thinking"] = encodeThinkingConfig(req.Thinking)
	}

	// output_config groups the structured-output format and the
	// thinking effort level; emit it only when at least one is set.
	outputConfig := map[string]any{}
	hasOutputConfig := false
	if req.OutputFormat != nil {
		of := encodeOutputFormat(req.OutputFormat)
		if of != nil {
			outputConfig["format"] = of
			hasOutputConfig = true
		}
	}
	if req.Thinking != nil && req.Thinking.Effort != "" {
		outputConfig["effort"] = req.Thinking.Effort
		hasOutputConfig = true
	}
	if hasOutputConfig {
		result["output_config"] = outputConfig
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 请求失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// encodeSystem 编码系统消息
|
||||
func encodeSystem(system any) any {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []canonical.SystemBlock:
|
||||
blocks := make([]map[string]any, len(v))
|
||||
for i, b := range v {
|
||||
blocks[i] = map[string]any{"text": b.Text}
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeMessages renders canonical messages as Anthropic messages,
// applying the protocol's role constraints:
//   - RoleTool messages are folded back into a user message (merged
//     into the immediately preceding user message when adjacent),
//     since Anthropic has no tool role.
//   - The first message must come from "user"; a placeholder is
//     prepended otherwise.
//   - Consecutive same-role messages are merged at the end.
func encodeMessages(msgs []canonical.CanonicalMessage) []map[string]any {
	var result []map[string]any

	for _, msg := range msgs {
		switch msg.Role {
		case canonical.RoleUser:
			result = append(result, map[string]any{
				"role":    "user",
				"content": encodeContentBlocks(msg.Content),
			})
		case canonical.RoleAssistant:
			result = append(result, map[string]any{
				"role":    "assistant",
				"content": encodeContentBlocks(msg.Content),
			})
		case canonical.RoleTool:
			// Tool results travel inside user messages on the
			// Anthropic side.
			toolResults := filterToolResults(msg.Content)
			if len(result) > 0 && result[len(result)-1]["role"] == "user" {
				// Append to the immediately preceding user message.
				lastContent, ok := result[len(result)-1]["content"].([]map[string]any)
				if ok {
					result[len(result)-1]["content"] = append(lastContent, toolResults...)
				} else {
					result[len(result)-1]["content"] = toolResults
				}
			} else {
				result = append(result, map[string]any{
					"role":    "user",
					"content": toolResults,
				})
			}
		}
	}

	// The conversation must open with a user message.
	if len(result) > 0 && result[0]["role"] != "user" {
		result = append([]map[string]any{{"role": "user", "content": []map[string]any{}}}, result...)
	}

	// Collapse runs of same-role messages into one.
	result = mergeConsecutiveRoles(result)

	return result
}
|
||||
|
||||
// encodeContentBlocks renders canonical content blocks as Anthropic
// block objects. Unknown block types are dropped; an all-empty result
// degrades to a single empty text block so the message stays valid.
func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
	result := make([]map[string]any, 0, len(blocks))
	for _, b := range blocks {
		switch b.Type {
		case "text":
			result = append(result, map[string]any{"type": "text", "text": b.Text})
		case "tool_use":
			m := map[string]any{
				"type":  "tool_use",
				"id":    b.ID,
				"name":  b.Name,
				"input": b.Input,
			}
			if b.Input == nil {
				// Anthropic requires an input object; default to {}.
				m["input"] = map[string]any{}
			}
			result = append(result, m)
		case "tool_result":
			m := map[string]any{
				"type":        "tool_result",
				"tool_use_id": b.ToolUseID,
			}
			if b.Content != nil {
				var contentStr string
				// JSON-string content is unwrapped to a plain string.
				// NOTE(review): the else branch sends raw JSON text as
				// a string value rather than as structured blocks —
				// confirm this is intended for array-form tool results.
				if json.Unmarshal(b.Content, &contentStr) == nil {
					m["content"] = contentStr
				} else {
					m["content"] = string(b.Content)
				}
			} else {
				m["content"] = ""
			}
			if b.IsError != nil {
				m["is_error"] = *b.IsError
			}
			result = append(result, m)
		case "thinking":
			result = append(result, map[string]any{"type": "thinking", "thinking": b.Thinking})
		}
	}
	if len(result) == 0 {
		return []map[string]any{{"type": "text", "text": ""}}
	}
	return result
}
|
||||
|
||||
// filterToolResults 过滤工具结果
|
||||
func filterToolResults(blocks []canonical.ContentBlock) []map[string]any {
|
||||
var result []map[string]any
|
||||
for _, b := range blocks {
|
||||
if b.Type == "tool_result" {
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
var contentStr string
|
||||
if json.Unmarshal(b.Content, &contentStr) == nil {
|
||||
m["content"] = contentStr
|
||||
} else {
|
||||
m["content"] = string(b.Content)
|
||||
}
|
||||
} else {
|
||||
m["content"] = ""
|
||||
}
|
||||
if b.IsError != nil {
|
||||
m["is_error"] = *b.IsError
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return map[string]any{"type": "auto"}
|
||||
case "none":
|
||||
return map[string]any{"type": "none"}
|
||||
case "any":
|
||||
return map[string]any{"type": "any"}
|
||||
case "tool":
|
||||
return map[string]any{"type": "tool", "name": choice.Name}
|
||||
}
|
||||
return map[string]any{"type": "auto"}
|
||||
}
|
||||
|
||||
// encodeThinkingConfig 编码思考配置
|
||||
func encodeThinkingConfig(cfg *canonical.ThinkingConfig) map[string]any {
|
||||
switch cfg.Type {
|
||||
case "enabled":
|
||||
m := map[string]any{"type": "enabled"}
|
||||
if cfg.BudgetTokens != nil {
|
||||
m["budget_tokens"] = *cfg.BudgetTokens
|
||||
}
|
||||
return m
|
||||
case "disabled":
|
||||
return map[string]any{"type": "disabled"}
|
||||
case "adaptive":
|
||||
return map[string]any{"type": "adaptive"}
|
||||
}
|
||||
return map[string]any{"type": "disabled"}
|
||||
}
|
||||
|
||||
// encodeOutputFormat renders a canonical output format as the
// Anthropic output_config.format object. json_object is widened to a
// json_schema with a bare object schema; plain text (and unknown
// types) yields nil, meaning no format constraint.
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
	if format == nil {
		return nil
	}
	switch format.Type {
	case "json_schema":
		schema := format.Schema
		if schema == nil {
			// Defensive default: an unconstrained object schema.
			schema = json.RawMessage(`{"type":"object"}`)
		}
		return map[string]any{
			"type":   "json_schema",
			"schema": schema,
		}
	case "json_object":
		return map[string]any{
			"type":   "json_schema",
			"schema": map[string]any{"type": "object"},
		}
	case "text":
		return nil
	}
	return nil
}
|
||||
|
||||
// encodeResponse renders a canonical response as a non-streaming
// Anthropic Messages response body. Only text, tool_use, and thinking
// blocks appear in assistant output; other block types are skipped.
// Cache token counts are emitted only when present, and a missing
// stop reason defaults to end_turn.
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	blocks := make([]map[string]any, 0, len(resp.Content))
	for _, b := range resp.Content {
		switch b.Type {
		case "text":
			blocks = append(blocks, map[string]any{"type": "text", "text": b.Text})
		case "tool_use":
			m := map[string]any{
				"type":  "tool_use",
				"id":    b.ID,
				"name":  b.Name,
				"input": b.Input,
			}
			if b.Input == nil {
				// Anthropic requires an input object; default to {}.
				m["input"] = map[string]any{}
			}
			blocks = append(blocks, m)
		case "thinking":
			blocks = append(blocks, map[string]any{"type": "thinking", "thinking": b.Thinking})
		}
	}

	sr := "end_turn"
	if resp.StopReason != nil {
		sr = mapCanonicalStopReason(*resp.StopReason)
	}

	usage := map[string]any{
		"input_tokens":  resp.Usage.InputTokens,
		"output_tokens": resp.Usage.OutputTokens,
	}
	if resp.Usage.CacheReadTokens != nil {
		usage["cache_read_input_tokens"] = *resp.Usage.CacheReadTokens
	}
	if resp.Usage.CacheCreationTokens != nil {
		usage["cache_creation_input_tokens"] = *resp.Usage.CacheCreationTokens
	}

	result := map[string]any{
		"id":            resp.ID,
		"type":          "message",
		"role":          "assistant",
		"model":         resp.Model,
		"content":       blocks,
		"stop_reason":   sr,
		"stop_sequence": nil,
		"usage":         usage,
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 响应失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// mapCanonicalStopReason 映射 Canonical 停止原因到 Anthropic
|
||||
func mapCanonicalStopReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn, canonical.StopReasonContentFilter:
|
||||
return "end_turn"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "max_tokens"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_use"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop_sequence"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "refusal"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
name := m.Name
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(m.Created),
|
||||
}
|
||||
}
|
||||
|
||||
var firstID, lastID *string
|
||||
if len(list.Models) > 0 {
|
||||
fid := list.Models[0].ID
|
||||
firstID = &fid
|
||||
lid := list.Models[len(list.Models)-1].ID
|
||||
lastID = &lid
|
||||
}
|
||||
|
||||
return json.Marshal(map[string]any{
|
||||
"data": data,
|
||||
"has_more": false,
|
||||
"first_id": firstID,
|
||||
"last_id": lastID,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
name := info.Name
|
||||
if name == "" {
|
||||
name = info.ID
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(info.Created),
|
||||
})
|
||||
}
|
||||
|
||||
// mergeConsecutiveRoles merges adjacent messages that share the same role.
// String contents are concatenated and block-list contents are appended.
//
// Bug fix: the previous implementation executed `continue` even when the two
// contents had incompatible shapes (e.g. string vs []map[string]any, or a
// type the switch did not cover), silently dropping the second message's
// content. Incompatible neighbors are now kept as separate entries instead.
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
	if len(messages) <= 1 {
		return messages
	}
	var result []map[string]any
	for _, msg := range messages {
		if len(result) > 0 {
			last := result[len(result)-1]
			if last["role"] == msg["role"] {
				if merged, ok := mergeMessageContent(last["content"], msg["content"]); ok {
					last["content"] = merged
					continue
				}
				// Shapes differ: fall through and append msg unchanged so no
				// content is lost (a duplicate role beats silent data loss).
			}
		}
		result = append(result, msg)
	}
	return result
}

// mergeMessageContent concatenates two content values when both are strings
// or both are block slices; ok is false when the shapes are incompatible.
func mergeMessageContent(prev, next any) (any, bool) {
	switch pv := prev.(type) {
	case []map[string]any:
		if nv, ok := next.([]map[string]any); ok {
			return append(pv, nv...), true
		}
	case string:
		if nv, ok := next.(string); ok {
			return pv + nv, true
		}
	}
	return nil, false
}
|
||||
350
backend/internal/conversion/anthropic/encoder_test.go
Normal file
350
backend/internal/conversion/anthropic/encoder_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEncodeRequest_Basic verifies that the provider's model override, the
// stream flag, max_tokens and the message list survive request encoding.
func TestEncodeRequest_Basic(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Stream:     true,
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "my-model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "my-model", result["model"])
	assert.Equal(t, true, result["stream"])
	// json.Unmarshal decodes JSON numbers into float64.
	assert.Equal(t, float64(1024), result["max_tokens"])

	msgs := result["messages"].([]any)
	assert.Len(t, msgs, 1)
}

// TestEncodeRequest_ToolMergeIntoUser verifies that a canonical tool message
// is folded into an adjacent user message as a tool_result content block.
func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("查询")}},
			{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", json.RawMessage(`{"q":"test"}`))}},
			{Role: canonical.RoleTool, Content: []canonical.ContentBlock{canonical.NewToolResultBlock("tool_1", "结果", false)}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)

	// The tool message should have been merged into the adjacent user message.
	foundToolResult := false
	for _, m := range msgs {
		msgMap := m.(map[string]any)
		if msgMap["role"] == "user" {
			content, ok := msgMap["content"].([]any)
			if ok {
				for _, c := range content {
					block := c.(map[string]any)
					if block["type"] == "tool_result" {
						foundToolResult = true
					}
				}
			}
		}
	}
	assert.True(t, foundToolResult)
}

// TestEncodeRequest_FirstUserGuarantee verifies that the encoded message list
// always starts with a user-role message even when the canonical request
// begins with an assistant message.
func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewTextBlock("前置")}},
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)
	firstMsg := msgs[0].(map[string]any)
	assert.Equal(t, "user", firstMsg["role"])
}

// TestEncodeRequest_ThinkingEnabled verifies that an enabled thinking config
// with a budget is emitted as Anthropic's thinking object.
func TestEncodeRequest_ThinkingEnabled(t *testing.T) {
	budget := 10000
	maxTokens := 8096
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking:   &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	thinking, ok := result["thinking"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "enabled", thinking["type"])
	assert.Equal(t, float64(10000), thinking["budget_tokens"])
}

// TestEncodeResponse_Basic verifies the Anthropic message envelope (id, type,
// role, stop_reason) and a single text content block.
func TestEncodeResponse_Basic(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "msg_1",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "msg_1", result["id"])
	assert.Equal(t, "message", result["type"])
	assert.Equal(t, "assistant", result["role"])
	assert.Equal(t, "end_turn", result["stop_reason"])

	content := result["content"].([]any)
	assert.Len(t, content, 1)
	block := content[0].(map[string]any)
	assert.Equal(t, "text", block["type"])
	assert.Equal(t, "你好", block["text"])
}

// TestEncodeModelsResponse verifies the model list shape and the timestamp
// formatting of created_at.
func TestEncodeModelsResponse(t *testing.T) {
	ts := time.Date(2024, 3, 15, 0, 0, 0, 0, time.UTC).Unix()
	list := &canonical.CanonicalModelList{
		Models: []canonical.CanonicalModel{
			{ID: "claude-3-opus", Name: "Claude 3 Opus", Created: ts, OwnedBy: "anthropic"},
		},
	}

	body, err := encodeModelsResponse(list)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	data := result["data"].([]any)
	assert.Len(t, data, 1)

	model := data[0].(map[string]any)
	assert.Equal(t, "claude-3-opus", model["id"])
	// created_at should be an RFC3339-formatted string.
	createdAt, ok := model["created_at"].(string)
	assert.True(t, ok)
	assert.Contains(t, createdAt, "2024")
}
|
||||
|
||||
// TestEncodeRequest_ThinkingDisabled verifies that no thinking object is
// emitted when the canonical request carries no thinking config.
func TestEncodeRequest_ThinkingDisabled(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	_, hasThinking := result["thinking"]
	assert.False(t, hasThinking)
}

// TestEncodeRequest_ThinkingAdaptive verifies that an "adaptive" thinking
// config (no budget) is passed through with its type intact.
func TestEncodeRequest_ThinkingAdaptive(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking:   &canonical.ThinkingConfig{Type: "adaptive"},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	thinking, ok := result["thinking"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "adaptive", thinking["type"])
}

// TestEncodeRequest_OutputFormat_JSONSchema verifies that a json_schema output
// format is encoded under output_config.format with its schema attached.
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	maxTokens := 1024
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type:   "json_schema",
			Schema: schema,
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	oc, ok := result["output_config"].(map[string]any)
	require.True(t, ok)
	format, ok := oc["format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", format["type"])
	assert.NotNil(t, format["schema"])
}

// TestEncodeRequest_OutputFormat_JSON verifies that a plain json_object format
// is downgraded to json_schema with a generic object schema.
func TestEncodeRequest_OutputFormat_JSON(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type: "json_object",
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	oc, ok := result["output_config"].(map[string]any)
	require.True(t, ok)
	format, ok := oc["format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", format["type"])
	schemaMap, ok := format["schema"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "object", schemaMap["type"])
}

// TestEncodeRequest_ConsecutiveRoleMerge verifies that two consecutive user
// messages are merged into one message with both content blocks.
func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("A")}},
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("B")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)
	assert.Len(t, msgs, 1)
	userMsg := msgs[0].(map[string]any)
	assert.Equal(t, "user", userMsg["role"])
	content := userMsg["content"].([]any)
	assert.Len(t, content, 2)
}

// TestEncodeResponse_ContentFilter verifies that the canonical content_filter
// stop reason maps to Anthropic's "end_turn".
func TestEncodeResponse_ContentFilter(t *testing.T) {
	sr := canonical.StopReasonContentFilter
	resp := &canonical.CanonicalResponse{
		ID:         "msg-cf",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("hi")},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "end_turn", result["stop_reason"])
}

// TestEncodeResponse_ReasoningTokens verifies that canonical reasoning token
// counts are NOT surfaced in the Anthropic usage object.
func TestEncodeResponse_ReasoningTokens(t *testing.T) {
	reasoning := 100
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "msg-rt",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("hi")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5, ReasoningTokens: &reasoning},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	usage := result["usage"].(map[string]any)
	_, hasReasoning := usage["reasoning_tokens"]
	assert.False(t, hasReasoning)
}

// TestEncodeResponse_ToolUse verifies that a tool_use content block keeps its
// id and name through encoding.
func TestEncodeResponse_ToolUse(t *testing.T) {
	sr := canonical.StopReasonToolUse
	input := json.RawMessage(`{"q":"test"}`)
	resp := &canonical.CanonicalResponse{
		ID:         "msg-tool",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", input)},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	content := result["content"].([]any)
	assert.Len(t, content, 1)
	block := content[0].(map[string]any)
	assert.Equal(t, "tool_use", block["type"])
	assert.Equal(t, "tool_1", block["id"])
	assert.Equal(t, "search", block["name"])
}
|
||||
283
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
283
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder incrementally decodes Anthropic SSE chunks into canonical
// stream events. It holds per-stream state and does no internal locking, so
// one decoder must be used by a single goroutine per stream.
type StreamDecoder struct {
	messageStarted   bool                      // true once a message_start event has been decoded
	redactedBlocks   map[int]bool              // indexes of content blocks whose events are suppressed (redacted_thinking / server tool blocks)
	utf8Remainder    []byte                    // trailing bytes of a UTF-8 rune split across chunk boundaries
	accumulatedUsage *canonical.CanonicalUsage // usage seeded at message_start; output tokens added on each message_delta
}
|
||||
|
||||
// NewStreamDecoder 创建 Anthropic 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
redactedBlocks: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk parses one raw SSE chunk into canonical stream events.
// Incomplete trailing UTF-8 sequences are buffered and prepended to the next
// chunk so multi-byte characters split across chunks are not corrupted.
//
// NOTE(review): an SSE line split across two chunks (e.g. a "data: {..." line
// ending mid-chunk) is not buffered and its event will be lost — confirm the
// upstream reader always delivers whole event/data lines per chunk.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		// Prepend bytes carried over from the previous chunk.
		data = append(d.utf8Remainder, rawChunk...)
		d.utf8Remainder = nil
	}

	if !utf8.Valid(data) {
		// Walk backwards to the longest valid UTF-8 prefix and carry the
		// partial rune bytes over to the next call.
		validEnd := len(data)
		for !utf8.Valid(data[:validEnd]) {
			validEnd--
		}
		d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
		data = data[:validEnd]
	}

	var events []canonical.CanonicalStreamEvent
	text := string(data)

	// Parse named SSE events of the form "event: <type>\ndata: <json>".
	var eventType string
	var eventData string

	for _, line := range strings.Split(text, "\n") {
		line = strings.TrimRight(line, "\r")
		if strings.HasPrefix(line, "event: ") {
			eventType = strings.TrimPrefix(line, "event: ")
		} else if strings.HasPrefix(line, "data: ") {
			eventData = strings.TrimPrefix(line, "data: ")
			// Dispatch as soon as both halves of the named event are present.
			if eventType != "" && eventData != "" {
				chunkEvents := d.processEvent(eventType, []byte(eventData))
				events = append(events, chunkEvents...)
			}
			eventType = ""
			eventData = ""
		} else if line == "" {
			// SSE event separator; state was already reset on the data line.
		}
	}

	return events
}
|
||||
|
||||
// Flush returns any events still buffered at end of stream. This decoder
// dispatches events eagerly per chunk and holds none back, so it always
// returns nil (any bytes left in utf8Remainder are dropped).
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
	return nil
}
|
||||
|
||||
// processEvent dispatches a single named SSE event to its handler and returns
// the resulting canonical events. Unknown event types are silently ignored.
func (d *StreamDecoder) processEvent(eventType string, data []byte) []canonical.CanonicalStreamEvent {
	switch eventType {
	case "message_start":
		return d.processMessageStart(data)
	case "content_block_start":
		return d.processContentBlockStart(data)
	case "content_block_delta":
		return d.processContentBlockDelta(data)
	case "content_block_stop":
		return d.processContentBlockStop(data)
	case "message_delta":
		return d.processMessageDelta(data)
	case "message_stop":
		return d.processMessageStop(data)
	case "ping":
		return []canonical.CanonicalStreamEvent{canonical.NewPingEvent()}
	case "error":
		return d.processError(data)
	}
	return nil
}
|
||||
|
||||
// processMessageStart 处理消息开始事件
|
||||
func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Usage *struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if msgRaw, ok := raw["message"]; ok {
|
||||
if err := json.Unmarshal(msgRaw, &msg); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
event := canonical.NewMessageStartEvent(msg.ID, msg.Model)
|
||||
if msg.Usage != nil {
|
||||
usage := &canonical.CanonicalUsage{
|
||||
InputTokens: msg.Usage.InputTokens,
|
||||
OutputTokens: msg.Usage.OutputTokens,
|
||||
}
|
||||
event = canonical.NewMessageStartEventWithUsage(msg.ID, msg.Model, usage)
|
||||
d.accumulatedUsage = usage
|
||||
}
|
||||
|
||||
d.messageStarted = true
|
||||
return []canonical.CanonicalStreamEvent{event}
|
||||
}
|
||||
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Thinking string `json:"thinking"`
|
||||
Data string `json:"data"`
|
||||
} `json:"content_block"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查需要丢弃的块类型
|
||||
switch raw.ContentBlock.Type {
|
||||
case "redacted_thinking", "server_tool_use", "web_search_tool_result",
|
||||
"code_execution_tool_result":
|
||||
d.redactedBlocks[raw.Index] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
block := canonical.StreamContentBlock{
|
||||
Type: raw.ContentBlock.Type,
|
||||
Text: raw.ContentBlock.Text,
|
||||
ID: raw.ContentBlock.ID,
|
||||
Name: raw.ContentBlock.Name,
|
||||
Thinking: raw.ContentBlock.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStartEvent(raw.Index, block),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockDelta 处理内容块增量事件
|
||||
func (d *StreamDecoder) processContentBlockDelta(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
PartialJSON string `json:"partial_json"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"delta"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否在丢弃的块中
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 丢弃协议特有 delta 类型
|
||||
switch raw.Delta.Type {
|
||||
case "citations_delta", "signature_delta":
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := canonical.StreamDelta{
|
||||
Type: raw.Delta.Type,
|
||||
Text: raw.Delta.Text,
|
||||
PartialJSON: raw.Delta.PartialJSON,
|
||||
Thinking: raw.Delta.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockDeltaEvent(raw.Index, delta),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockStop 处理内容块结束事件
|
||||
func (d *StreamDecoder) processContentBlockStop(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, redacted := d.redactedBlocks[raw.Index]; redacted {
|
||||
delete(d.redactedBlocks, raw.Index)
|
||||
return nil
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStopEvent(raw.Index),
|
||||
}
|
||||
}
|
||||
|
||||
// processMessageDelta handles the message_delta event, which carries the stop
// reason and an output-token count, and emits a canonical message-delta event
// with that usage.
//
// NOTE(review): Anthropic documents message_delta usage.output_tokens as a
// cumulative total for the message; adding it into accumulatedUsage (which
// was seeded from message_start) may double count — confirm the intended
// semantics of accumulatedUsage against its consumers.
func (d *StreamDecoder) processMessageDelta(data []byte) []canonical.CanonicalStreamEvent {
	var raw struct {
		Delta struct {
			StopReason string `json:"stop_reason"`
		} `json:"delta"`
		Usage struct {
			OutputTokens int `json:"output_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return nil
	}

	sr := mapStopReason(raw.Delta.StopReason)
	// Usage attached to the emitted event reflects only this delta's count.
	usage := &canonical.CanonicalUsage{
		OutputTokens: raw.Usage.OutputTokens,
	}

	if d.accumulatedUsage != nil {
		d.accumulatedUsage.OutputTokens += raw.Usage.OutputTokens
	}

	return []canonical.CanonicalStreamEvent{
		canonical.NewMessageDeltaEventWithUsage(sr, usage),
	}
}
|
||||
|
||||
// processMessageStop handles the message_stop event. The payload carries
// nothing the canonical model needs, so data is intentionally unused; the
// parameter is kept for a uniform handler signature.
func (d *StreamDecoder) processMessageStop(data []byte) []canonical.CanonicalStreamEvent {
	return []canonical.CanonicalStreamEvent{canonical.NewMessageStopEvent()}
}
|
||||
|
||||
// processError 处理错误事件
|
||||
func (d *StreamDecoder) processError(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent("stream_error", fmt.Sprintf("解析错误事件失败: %s", string(data))),
|
||||
}
|
||||
}
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent(raw.Error.Type, raw.Error.Message),
|
||||
}
|
||||
}
|
||||
274
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
274
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeAnthropicEvent renders a named SSE event with a JSON-encoded data line,
// terminated by the blank-line event separator.
func makeAnthropicEvent(eventType string, data any) []byte {
	// Test helper: inputs are always marshalable, so the error is ignored.
	body, _ := json.Marshal(data)
	return fmt.Appendf(nil, "event: %s\ndata: %s\n\n", eventType, body)
}
|
||||
|
||||
// TestStreamDecoder_MessageStart verifies that a message_start SSE event is
// decoded into a canonical message-start event with ID and model.
func TestStreamDecoder_MessageStart(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":    "msg_1",
			"model": "claude-3",
			"usage": map[string]any{"input_tokens": 10, "output_tokens": 0},
		},
	}
	raw := makeAnthropicEvent("message_start", payload)

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)
	assert.Equal(t, canonical.EventMessageStart, events[0].Type)
	assert.Equal(t, "msg_1", events[0].Message.ID)
	assert.Equal(t, "claude-3", events[0].Message.Model)
}

// TestStreamDecoder_ContentBlockDelta exercises the three supported delta
// types (text, tool-input JSON, thinking) through one shared decoder; decoder
// state carried across subtests is harmless for deltas on block index 0.
func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
	d := NewStreamDecoder()

	tests := []struct {
		name       string
		deltaType  string
		deltaData  map[string]any
		checkField string
		checkValue string
	}{
		{
			name:       "text_delta",
			deltaType:  "text_delta",
			deltaData:  map[string]any{"type": "text_delta", "text": "你好"},
			checkField: "text",
			checkValue: "你好",
		},
		{
			name:       "input_json_delta",
			deltaType:  "input_json_delta",
			deltaData:  map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
			checkField: "partial_json",
			checkValue: "{\"key\":",
		},
		{
			name:       "thinking_delta",
			deltaType:  "thinking_delta",
			deltaData:  map[string]any{"type": "thinking_delta", "thinking": "思考中"},
			checkField: "thinking",
			checkValue: "思考中",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			payload := map[string]any{
				"type":  "content_block_delta",
				"index": 0,
				"delta": tt.deltaData,
			}
			raw := makeAnthropicEvent("content_block_delta", payload)

			events := d.ProcessChunk(raw)
			require.NotEmpty(t, events)
			assert.Equal(t, canonical.EventContentBlockDelta, events[0].Type)
			assert.NotNil(t, events[0].Delta)

			switch tt.checkField {
			case "text":
				assert.Equal(t, tt.checkValue, events[0].Delta.Text)
			case "partial_json":
				assert.Equal(t, tt.checkValue, events[0].Delta.PartialJSON)
			case "thinking":
				assert.Equal(t, tt.checkValue, events[0].Delta.Thinking)
			}
		})
	}
}

// TestStreamDecoder_RedactedThinking verifies that a redacted_thinking block
// start is suppressed and the block index is marked for suppression.
func TestStreamDecoder_RedactedThinking(t *testing.T) {
	d := NewStreamDecoder()

	// A redacted_thinking content_block_start must be suppressed.
	payload := map[string]any{
		"type":  "content_block_start",
		"index": 1,
		"content_block": map[string]any{
			"type": "redacted_thinking",
			"data": "redacted-data",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	assert.True(t, d.redactedBlocks[1])
}

// TestStreamDecoder_RedactedBlockStopSuppressed verifies that the stop event
// of a suppressed block is swallowed and the suppression entry is removed.
func TestStreamDecoder_RedactedBlockStopSuppressed(t *testing.T) {
	d := NewStreamDecoder()
	d.redactedBlocks[2] = true

	// content_block_stop for a redacted block must return no events.
	payload := map[string]any{
		"type":  "content_block_stop",
		"index": 2,
	}
	raw := makeAnthropicEvent("content_block_stop", payload)

	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	// The redactedBlocks entry should have been cleaned up.
	_, exists := d.redactedBlocks[2]
	assert.False(t, exists)
}

// TestStreamDecoder_ContentBlockStart verifies that a text block start is
// decoded with its type and index intact.
func TestStreamDecoder_ContentBlockStart(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 0,
		"content_block": map[string]any{
			"type": "text",
			"text": "",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
	require.NotNil(t, events[0].ContentBlock)
	assert.Equal(t, "text", events[0].ContentBlock.Type)
	require.NotNil(t, events[0].Index)
	assert.Equal(t, 0, *events[0].Index)
}
|
||||
|
||||
// TestStreamDecoder_ContentBlockStart_ToolUse verifies that a tool_use block
// start carries through its id and name.
func TestStreamDecoder_ContentBlockStart_ToolUse(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 1,
		"content_block": map[string]any{
			"type": "tool_use",
			"id":   "toolu_1",
			"name": "search",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
	require.NotNil(t, events[0].ContentBlock)
	assert.Equal(t, "tool_use", events[0].ContentBlock.Type)
	assert.Equal(t, "toolu_1", events[0].ContentBlock.ID)
	assert.Equal(t, "search", events[0].ContentBlock.Name)
}

// TestStreamDecoder_ContentBlockStop verifies decoding of a plain block stop.
func TestStreamDecoder_ContentBlockStop(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_stop",
		"index": 0,
	}
	raw := makeAnthropicEvent("content_block_stop", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStop, events[0].Type)
	require.NotNil(t, events[0].Index)
	assert.Equal(t, 0, *events[0].Index)
}

// TestStreamDecoder_MessageDelta verifies that a message_delta event yields
// the mapped stop reason and the delta's output-token count.
func TestStreamDecoder_MessageDelta(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "message_delta",
		"delta": map[string]any{
			"stop_reason": "end_turn",
		},
		"usage": map[string]any{
			"output_tokens": 42,
		},
	}
	raw := makeAnthropicEvent("message_delta", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventMessageDelta, events[0].Type)
	require.NotNil(t, events[0].StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *events[0].StopReason)
	require.NotNil(t, events[0].Usage)
	assert.Equal(t, 42, events[0].Usage.OutputTokens)
}

// TestStreamDecoder_MessageStop verifies decoding of the message_stop event.
func TestStreamDecoder_MessageStop(t *testing.T) {
	d := NewStreamDecoder()

	raw := makeAnthropicEvent("message_stop", map[string]any{"type": "message_stop"})

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventMessageStop, events[0].Type)
}

// TestStreamDecoder_Ping verifies decoding of the keep-alive ping event.
func TestStreamDecoder_Ping(t *testing.T) {
	d := NewStreamDecoder()

	raw := makeAnthropicEvent("ping", map[string]any{"type": "ping"})

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventPing, events[0].Type)
}

// TestStreamDecoder_Error verifies that a stream error event keeps its type
// and message through decoding.
func TestStreamDecoder_Error(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "error",
		"error": map[string]any{
			"type":    "overloaded_error",
			"message": "服务过载",
		},
	}
	raw := makeAnthropicEvent("error", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventError, events[0].Type)
	require.NotNil(t, events[0].Error)
	assert.Equal(t, "overloaded_error", events[0].Error.Type)
	assert.Equal(t, "服务过载", events[0].Error.Message)
}

// TestStreamDecoder_RedactedDeltaSuppressed verifies that deltas belonging to
// a suppressed block index produce no events.
func TestStreamDecoder_RedactedDeltaSuppressed(t *testing.T) {
	d := NewStreamDecoder()
	d.redactedBlocks[1] = true

	payload := map[string]any{
		"type":  "content_block_delta",
		"index": 1,
		"delta": map[string]any{
			"type": "text_delta",
			"text": "被抑制的内容",
		},
	}
	raw := makeAnthropicEvent("content_block_delta", payload)

	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}
|
||||
188
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
188
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder encodes canonical stream events into Anthropic-style named
// SSE events. It is stateless and never buffers output.
type StreamEncoder struct{}

// NewStreamEncoder creates a new Anthropic stream encoder.
func NewStreamEncoder() *StreamEncoder {
	return new(StreamEncoder)
}
|
||||
|
||||
// EncodeEvent 编码 Canonical 事件为 Anthropic 命名 SSE 事件
|
||||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
switch event.Type {
|
||||
case canonical.EventMessageStart:
|
||||
return e.encodeMessageStart(event)
|
||||
case canonical.EventContentBlockStart:
|
||||
return e.encodeContentBlockStart(event)
|
||||
case canonical.EventContentBlockDelta:
|
||||
return e.encodeContentBlockDelta(event)
|
||||
case canonical.EventContentBlockStop:
|
||||
return e.encodeContentBlockStop(event)
|
||||
case canonical.EventMessageDelta:
|
||||
return e.encodeMessageDelta(event)
|
||||
case canonical.EventMessageStop:
|
||||
return e.encodeMessageStop(event)
|
||||
case canonical.EventPing:
|
||||
return e.encodePing()
|
||||
case canonical.EventError:
|
||||
return e.encodeError(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区(无缓冲)
|
||||
func (e *StreamEncoder) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{
|
||||
"type": "message_start",
|
||||
}
|
||||
if event.Message != nil {
|
||||
msg := map[string]any{
|
||||
"id": event.Message.ID,
|
||||
"model": event.Message.Model,
|
||||
"role": "assistant",
|
||||
}
|
||||
if event.Message.Usage != nil {
|
||||
usage := map[string]any{
|
||||
"input_tokens": event.Message.Usage.InputTokens,
|
||||
"output_tokens": event.Message.Usage.OutputTokens,
|
||||
}
|
||||
msg["usage"] = usage
|
||||
}
|
||||
payload["message"] = msg
|
||||
}
|
||||
return e.marshalEvent("message_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStart 编码内容块开始事件
|
||||
func (e *StreamEncoder) encodeContentBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.ContentBlock == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cb := map[string]any{
|
||||
"type": event.ContentBlock.Type,
|
||||
}
|
||||
switch event.ContentBlock.Type {
|
||||
case "text":
|
||||
cb["text"] = ""
|
||||
case "tool_use":
|
||||
cb["id"] = event.ContentBlock.ID
|
||||
cb["name"] = event.ContentBlock.Name
|
||||
cb["input"] = map[string]any{}
|
||||
case "thinking":
|
||||
cb["thinking"] = ""
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": *event.Index,
|
||||
"content_block": cb,
|
||||
}
|
||||
return e.marshalEvent("content_block_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := map[string]any{
|
||||
"type": event.Delta.Type,
|
||||
}
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
delta["text"] = event.Delta.Text
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
delta["partial_json"] = event.Delta.PartialJSON
|
||||
case canonical.DeltaTypeThinking:
|
||||
delta["thinking"] = event.Delta.Thinking
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": *event.Index,
|
||||
"delta": delta,
|
||||
}
|
||||
return e.marshalEvent("content_block_delta", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStop 编码内容块结束事件
|
||||
func (e *StreamEncoder) encodeContentBlockStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": *event.Index,
|
||||
}
|
||||
return e.marshalEvent("content_block_stop", payload)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{}
|
||||
if event.StopReason != nil {
|
||||
delta["stop_reason"] = mapCanonicalStopReason(*event.StopReason)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": delta,
|
||||
}
|
||||
if event.Usage != nil {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": event.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return e.marshalEvent("message_delta", payload)
|
||||
}
|
||||
|
||||
// encodeMessageStop 编码消息结束事件
|
||||
func (e *StreamEncoder) encodeMessageStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{"type": "message_stop"}
|
||||
return e.marshalEvent("message_stop", payload)
|
||||
}
|
||||
|
||||
// encodePing 编码心跳事件
|
||||
func (e *StreamEncoder) encodePing() [][]byte {
|
||||
payload := map[string]any{"type": "ping"}
|
||||
return e.marshalEvent("ping", payload)
|
||||
}
|
||||
|
||||
// encodeError 编码错误事件
|
||||
func (e *StreamEncoder) encodeError(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Error == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": event.Error.Type,
|
||||
"message": event.Error.Message,
|
||||
},
|
||||
}
|
||||
return e.marshalEvent("error", payload)
|
||||
}
|
||||
|
||||
// marshalEvent 序列化为 Anthropic 命名 SSE 事件
|
||||
func (e *StreamEncoder) marshalEvent(eventType string, payload map[string]any) [][]byte {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return [][]byte{[]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, data))}
|
||||
}
|
||||
242
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
242
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEvent("msg_1", "claude-3")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
assert.Contains(t, s, "msg_1")
|
||||
assert.Contains(t, s, "claude-3")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_delta\n"))
|
||||
assert.Contains(t, s, "你好")
|
||||
|
||||
// 验证 JSON 格式
|
||||
lines := strings.Split(s, "\n")
|
||||
var dataLine string
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
dataLine = strings.TrimPrefix(l, "data: ")
|
||||
break
|
||||
}
|
||||
}
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
|
||||
assert.Equal(t, "content_block_delta", payload["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStopEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_stop\n"))
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "toolu_1",
|
||||
Name: "search",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "toolu_1")
|
||||
assert.Contains(t, s, "search")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "thinking", Thinking: ""})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "thinking")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
idx := 2
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_stop\n"))
|
||||
assert.Contains(t, s, "content_block_stop")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
sr := canonical.StopReasonEndTurn
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "stop_reason")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
usage := canonical.CanonicalUsage{OutputTokens: 88}
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
Usage: &usage,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "output_tokens")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Ping(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewPingEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: ping\n"))
|
||||
assert.Contains(t, s, "ping")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Error(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewErrorEvent("overloaded_error", "服务过载")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: error\n"))
|
||||
assert.Contains(t, s, "overloaded_error")
|
||||
assert.Contains(t, s, "服务过载")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
chunks := e.Flush()
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.CanonicalStreamEvent{Type: "unknown_event_type"}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
183
backend/internal/conversion/anthropic/types.go
Normal file
183
backend/internal/conversion/anthropic/types.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MessagesRequest is an Anthropic Messages API request body.
type MessagesRequest struct {
	Model                  string           `json:"model"`
	Messages               []Message        `json:"messages"`
	System                 any              `json:"system,omitempty"`
	MaxTokens              int              `json:"max_tokens"`
	Temperature            *float64         `json:"temperature,omitempty"`
	TopP                   *float64         `json:"top_p,omitempty"`
	TopK                   *int             `json:"top_k,omitempty"`
	StopSequences          []string         `json:"stop_sequences,omitempty"`
	Stream                 bool             `json:"stream,omitempty"`
	Tools                  []Tool           `json:"tools,omitempty"`
	ToolChoice             any              `json:"tool_choice,omitempty"`
	Metadata               *RequestMetadata `json:"metadata,omitempty"`
	Thinking               *ThinkingConfig  `json:"thinking,omitempty"`
	OutputConfig           *OutputConfig    `json:"output_config,omitempty"`
	DisableParallelToolUse *bool            `json:"disable_parallel_tool_use,omitempty"`
	Container              any              `json:"container,omitempty"`
}

// RequestMetadata carries optional request metadata.
type RequestMetadata struct {
	UserID string `json:"user_id,omitempty"`
}

// ThinkingConfig configures extended thinking for a request.
type ThinkingConfig struct {
	Type         string `json:"type"`
	BudgetTokens *int   `json:"budget_tokens,omitempty"`
	Display      string `json:"display,omitempty"`
}

// OutputConfig configures the response output.
type OutputConfig struct {
	Format *OutputFormatConfig `json:"format,omitempty"`
	Effort string              `json:"effort,omitempty"`
}

// OutputFormatConfig configures the output format, optionally constrained by
// a raw JSON schema.
type OutputFormatConfig struct {
	Type   string          `json:"type"`
	Schema json.RawMessage `json:"schema,omitempty"`
}

// Message is a single Anthropic conversation message. Content is left as
// `any` because the wire format accepts more than one shape.
type Message struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}

// TextContent is a text content block.
type TextContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// ToolUseContent is a tool-invocation content block; Input holds the raw
// JSON arguments.
type ToolUseContent struct {
	Type  string          `json:"type"`
	ID    string          `json:"id"`
	Name  string          `json:"name"`
	Input json.RawMessage `json:"input"`
}

// ToolResultContent is a tool-result content block referencing the tool_use
// it answers.
type ToolResultContent struct {
	Type      string `json:"type"`
	ToolUseID string `json:"tool_use_id"`
	Content   any    `json:"content"`
	IsError   *bool  `json:"is_error,omitempty"`
}

// ThinkingContent is a thinking content block.
type ThinkingContent struct {
	Type     string `json:"type"`
	Thinking string `json:"thinking"`
}

// RedactedThinkingContent is a redacted-thinking content block; Data holds
// the opaque redacted payload.
type RedactedThinkingContent struct {
	Type string `json:"type"`
	Data string `json:"data"`
}

// Tool is an Anthropic tool definition; InputSchema is a raw JSON schema.
type Tool struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema"`
}
|
||||
|
||||
// MessagesResponse is an Anthropic Messages API response.
type MessagesResponse struct {
	ID           string         `json:"id"`
	Type         string         `json:"type"`
	Role         string         `json:"role"`
	Model        string         `json:"model"`
	Content      []ContentBlock `json:"content"`
	StopReason   string         `json:"stop_reason"`
	StopSequence *string        `json:"stop_sequence,omitempty"`
	StopDetails  any            `json:"stop_details,omitempty"`
	Container    any            `json:"container,omitempty"`
	Usage        ResponseUsage  `json:"usage"`
}

// ContentBlock is a response content block; which of the optional fields is
// populated depends on Type.
type ContentBlock struct {
	Type     string          `json:"type"`
	Text     string          `json:"text,omitempty"`
	ID       string          `json:"id,omitempty"`
	Name     string          `json:"name,omitempty"`
	Input    json.RawMessage `json:"input,omitempty"`
	Thinking string          `json:"thinking,omitempty"`
	Data     string          `json:"data,omitempty"`
}

// ResponseUsage reports token usage for a response, including optional
// prompt-cache counters.
type ResponseUsage struct {
	InputTokens              int  `json:"input_tokens"`
	OutputTokens             int  `json:"output_tokens"`
	CacheReadInputTokens     *int `json:"cache_read_input_tokens,omitempty"`
	CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
}
|
||||
|
||||
// ModelsResponse is an Anthropic model-list response.
type ModelsResponse struct {
	Data    []ModelItem `json:"data"`
	HasMore bool        `json:"has_more"`
	FirstID *string     `json:"first_id,omitempty"`
	LastID  *string     `json:"last_id,omitempty"`
}

// ModelItem is a single entry of the model list.
type ModelItem struct {
	ID          string `json:"id"`
	Type        string `json:"type"`
	DisplayName string `json:"display_name,omitempty"`
	CreatedAt   string `json:"created_at,omitempty"`
}

// ModelInfoResponse is an Anthropic model-detail response.
type ModelInfoResponse struct {
	ID          string `json:"id"`
	Type        string `json:"type"`
	DisplayName string `json:"display_name,omitempty"`
	CreatedAt   string `json:"created_at,omitempty"`
}

// EmbeddingRequest exists only for interface compatibility; the Anthropic
// protocol does not support embeddings.
type EmbeddingRequest struct{}

// EmbeddingResponse exists only for interface compatibility (no embeddings).
type EmbeddingResponse struct{}

// RerankRequest exists only for interface compatibility; the Anthropic
// protocol does not support reranking.
type RerankRequest struct{}

// RerankResponse exists only for interface compatibility (no reranking).
type RerankResponse struct{}
|
||||
|
||||
// ErrorResponse is an Anthropic error response envelope.
type ErrorResponse struct {
	Type  string      `json:"type"`
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the machine-readable error type and the
// human-readable message.
type ErrorDetail struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// SSEEvent is a parsed server-sent event: the event name plus its raw JSON
// data payload.
type SSEEvent struct {
	EventType string
	Data      json.RawMessage
}
|
||||
71
backend/internal/conversion/canonical/extended.go
Normal file
71
backend/internal/conversion/canonical/extended.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package canonical
|
||||
|
||||
// CanonicalModel is the protocol-neutral representation of a model.
type CanonicalModel struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Created int64  `json:"created,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}

// CanonicalModelList is the protocol-neutral model list.
type CanonicalModelList struct {
	Models []CanonicalModel `json:"models"`
}

// CanonicalModelInfo is the protocol-neutral model detail.
type CanonicalModelInfo struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Created int64  `json:"created,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}

// CanonicalEmbeddingRequest is the protocol-neutral embedding request.
type CanonicalEmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or []string
	EncodingFormat string `json:"encoding_format,omitempty"`
	Dimensions     *int   `json:"dimensions,omitempty"`
}

// CanonicalEmbeddingResponse is the protocol-neutral embedding response.
type CanonicalEmbeddingResponse struct {
	Data  []EmbeddingData `json:"data"`
	Model string          `json:"model"`
	Usage EmbeddingUsage  `json:"usage"`
}

// EmbeddingData is a single embedding result item.
type EmbeddingData struct {
	Index     int `json:"index"`
	Embedding any `json:"embedding"` // []float64 or a base64 string, depending on encoding format
}

// EmbeddingUsage reports embedding token usage.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// CanonicalRerankRequest is the protocol-neutral rerank request.
type CanonicalRerankRequest struct {
	Model           string   `json:"model"`
	Query           string   `json:"query"`
	Documents       []string `json:"documents"`
	TopN            *int     `json:"top_n,omitempty"`
	ReturnDocuments *bool    `json:"return_documents,omitempty"`
}

// CanonicalRerankResponse is the protocol-neutral rerank response.
type CanonicalRerankResponse struct {
	Results []RerankResult `json:"results"`
	Model   string         `json:"model"`
}

// RerankResult is a single rerank result item.
type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
	Document       *string `json:"document,omitempty"`
}
|
||||
156
backend/internal/conversion/canonical/stream.go
Normal file
156
backend/internal/conversion/canonical/stream.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package canonical
|
||||
|
||||
// StreamEventType enumerates canonical stream event types.
type StreamEventType string

const (
	EventMessageStart      StreamEventType = "message_start"
	EventContentBlockStart StreamEventType = "content_block_start"
	EventContentBlockDelta StreamEventType = "content_block_delta"
	EventContentBlockStop  StreamEventType = "content_block_stop"
	EventMessageDelta      StreamEventType = "message_delta"
	EventMessageStop       StreamEventType = "message_stop"
	EventError             StreamEventType = "error"
	EventPing              StreamEventType = "ping"
)

// DeltaType enumerates canonical delta types.
type DeltaType string

const (
	DeltaTypeText      DeltaType = "text_delta"
	DeltaTypeInputJSON DeltaType = "input_json_delta"
	DeltaTypeThinking  DeltaType = "thinking_delta"
)

// StreamDelta is a tagged union of stream delta payloads; the field that
// matches Type is the one populated.
type StreamDelta struct {
	Type        string `json:"type"`
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
	Thinking    string `json:"thinking,omitempty"`
}

// StreamContentBlock is a tagged union describing a content block at stream
// start; the fields populated depend on Type.
type StreamContentBlock struct {
	Type     string `json:"type"`
	Text     string `json:"text,omitempty"`
	ID       string `json:"id,omitempty"`
	Name     string `json:"name,omitempty"`
	Thinking string `json:"thinking,omitempty"`
}

// CanonicalStreamEvent is the tagged union of all canonical stream events;
// which pointer fields are set depends on Type.
type CanonicalStreamEvent struct {
	Type StreamEventType `json:"type"`

	// MessageStartEvent
	Message *StreamMessage `json:"message,omitempty"`

	// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
	Index        *int                `json:"index,omitempty"`
	ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
	Delta        *StreamDelta        `json:"delta,omitempty"`

	// MessageDeltaEvent
	StopReason *StopReason     `json:"stop_reason,omitempty"`
	Usage      *CanonicalUsage `json:"usage,omitempty"`

	// ErrorEvent
	Error *StreamError `json:"error,omitempty"`
}

// StreamMessage is the message summary carried by a message_start event.
type StreamMessage struct {
	ID    string          `json:"id"`
	Model string          `json:"model"`
	Usage *CanonicalUsage `json:"usage,omitempty"`
}

// StreamError is the error payload carried by an error event.
type StreamError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
|
||||
|
||||
// NewMessageStartEvent 创建消息开始事件
|
||||
func NewMessageStartEvent(id, model string) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageStart,
|
||||
Message: &StreamMessage{ID: id, Model: model},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageStartEventWithUsage 创建带用量的消息开始事件
|
||||
func NewMessageStartEventWithUsage(id, model string, usage *CanonicalUsage) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageStart,
|
||||
Message: &StreamMessage{ID: id, Model: model, Usage: usage},
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockStartEvent 创建内容块开始事件
|
||||
func NewContentBlockStartEvent(index int, block StreamContentBlock) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockStart,
|
||||
Index: &idx,
|
||||
ContentBlock: &block,
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockDeltaEvent 创建内容块增量事件
|
||||
func NewContentBlockDeltaEvent(index int, delta StreamDelta) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockDelta,
|
||||
Index: &idx,
|
||||
Delta: &delta,
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockStopEvent 创建内容块结束事件
|
||||
func NewContentBlockStopEvent(index int) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageDeltaEvent 创建消息增量事件
|
||||
func NewMessageDeltaEvent(stopReason StopReason) CanonicalStreamEvent {
|
||||
sr := stopReason
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageDeltaEventWithUsage 创建带用量的消息增量事件
|
||||
func NewMessageDeltaEventWithUsage(stopReason StopReason, usage *CanonicalUsage) CanonicalStreamEvent {
|
||||
sr := stopReason
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageStopEvent 创建消息结束事件
|
||||
func NewMessageStopEvent() CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{Type: EventMessageStop}
|
||||
}
|
||||
|
||||
// NewErrorEvent 创建错误事件
|
||||
func NewErrorEvent(errType, message string) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventError,
|
||||
Error: &StreamError{Type: errType, Message: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewPingEvent 创建心跳事件
|
||||
func NewPingEvent() CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{Type: EventPing}
|
||||
}
|
||||
208
backend/internal/conversion/canonical/types.go
Normal file
208
backend/internal/conversion/canonical/types.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package canonical
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MessageRole enumerates canonical message roles.
type MessageRole string

const (
	RoleSystem    MessageRole = "system"
	RoleUser      MessageRole = "user"
	RoleAssistant MessageRole = "assistant"
	RoleTool      MessageRole = "tool"
)

// StopReason enumerates canonical stop reasons.
type StopReason string

const (
	StopReasonEndTurn       StopReason = "end_turn"
	StopReasonMaxTokens     StopReason = "max_tokens"
	StopReasonToolUse       StopReason = "tool_use"
	StopReasonStopSequence  StopReason = "stop_sequence"
	StopReasonContentFilter StopReason = "content_filter"
	StopReasonRefusal       StopReason = "refusal"
)

// SystemBlock is a system-prompt text block.
type SystemBlock struct {
	Text string `json:"text"`
}

// ContentBlock is a discriminated union keyed on the Type field; only the
// field group matching Type is populated.
type ContentBlock struct {
	Type string `json:"type"`

	// TextBlock ("text")
	Text string `json:"text,omitempty"`

	// ToolUseBlock ("tool_use")
	ID    string          `json:"id,omitempty"`
	Name  string          `json:"name,omitempty"`
	Input json.RawMessage `json:"input,omitempty"`

	// ToolResultBlock ("tool_result")
	ToolUseID string          `json:"tool_use_id,omitempty"`
	Content   json.RawMessage `json:"content,omitempty"`
	IsError   *bool           `json:"is_error,omitempty"`

	// ThinkingBlock ("thinking")
	Thinking string `json:"thinking,omitempty"`
}
|
||||
|
||||
// NewTextBlock 创建文本块
|
||||
func NewTextBlock(text string) ContentBlock {
|
||||
return ContentBlock{Type: "text", Text: text}
|
||||
}
|
||||
|
||||
// NewToolUseBlock 创建工具调用块
|
||||
func NewToolUseBlock(id, name string, input json.RawMessage) ContentBlock {
|
||||
return ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
}
|
||||
|
||||
// NewToolResultBlock 创建工具结果块
|
||||
func NewToolResultBlock(toolUseID string, content string, isError bool) ContentBlock {
|
||||
errFlag := &isError
|
||||
return ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: json.RawMessage(fmt.Sprintf("%q", content)),
|
||||
IsError: errFlag,
|
||||
}
|
||||
}
|
||||
|
||||
// NewThinkingBlock 创建思考块
|
||||
func NewThinkingBlock(thinking string) ContentBlock {
|
||||
return ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
}
|
||||
|
||||
// CanonicalMessage is one protocol-neutral conversation turn: a role plus an
// ordered list of content blocks.
type CanonicalMessage struct {
	Role    MessageRole    `json:"role"`
	Content []ContentBlock `json:"content"`
}
|
||||
|
||||
// CanonicalTool is a protocol-neutral tool definition. InputSchema holds the
// tool's JSON-Schema parameters verbatim, deferred as raw JSON.
type CanonicalTool struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema"`
}
|
||||
|
||||
// ToolChoice is a union over Type ("auto", "none", "any", "tool");
// Name is only set when Type == "tool" and names the forced tool.
type ToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}
|
||||
|
||||
// NewToolChoiceAuto 创建自动工具选择
|
||||
func NewToolChoiceAuto() *ToolChoice {
|
||||
return &ToolChoice{Type: "auto"}
|
||||
}
|
||||
|
||||
// NewToolChoiceNone 创建无工具选择
|
||||
func NewToolChoiceNone() *ToolChoice {
|
||||
return &ToolChoice{Type: "none"}
|
||||
}
|
||||
|
||||
// NewToolChoiceAny 创建任意工具选择
|
||||
func NewToolChoiceAny() *ToolChoice {
|
||||
return &ToolChoice{Type: "any"}
|
||||
}
|
||||
|
||||
// NewToolChoiceNamed 创建指定工具选择
|
||||
func NewToolChoiceNamed(name string) *ToolChoice {
|
||||
return &ToolChoice{Type: "tool", Name: name}
|
||||
}
|
||||
|
||||
// RequestParameters carries the common sampling parameters. All fields are
// pointers so "unset" is distinguishable from an explicit zero value.
type RequestParameters struct {
	MaxTokens        *int      `json:"max_tokens,omitempty"`
	Temperature      *float64  `json:"temperature,omitempty"`
	TopP             *float64  `json:"top_p,omitempty"`
	TopK             *int      `json:"top_k,omitempty"`
	FrequencyPenalty *float64  `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64  `json:"presence_penalty,omitempty"`
	StopSequences    []string  `json:"stop_sequences,omitempty"`
}
|
||||
|
||||
// ThinkingConfig configures extended thinking/reasoning. BudgetTokens and
// Effort are alternative knobs used by different upstream protocols; which
// one applies is adapter-specific.
type ThinkingConfig struct {
	Type         string `json:"type"`
	BudgetTokens *int   `json:"budget_tokens,omitempty"`
	Effort       string `json:"effort,omitempty"`
}
|
||||
|
||||
// OutputFormat is a union over Type describing the requested structured
// output; Name/Schema/Strict apply to schema-constrained formats.
type OutputFormat struct {
	Type   string          `json:"type"`
	Name   string          `json:"name,omitempty"`
	Schema json.RawMessage `json:"schema,omitempty"`
	Strict *bool           `json:"strict,omitempty"`
}
|
||||
|
||||
// CanonicalRequest is the protocol-neutral chat request that adapters decode
// into and encode from.
type CanonicalRequest struct {
	Model string `json:"model"`
	// System is nil, a string, or a []SystemBlock; use GetSystemString /
	// SetSystemString for the common string view.
	System          any                `json:"system,omitempty"`
	Messages        []CanonicalMessage `json:"messages"`
	Tools           []CanonicalTool    `json:"tools,omitempty"`
	ToolChoice      *ToolChoice        `json:"tool_choice,omitempty"`
	Parameters      RequestParameters  `json:"parameters"`
	Thinking        *ThinkingConfig    `json:"thinking,omitempty"`
	Stream          bool               `json:"stream"`
	UserID          string             `json:"user_id,omitempty"`
	OutputFormat    *OutputFormat      `json:"output_format,omitempty"`
	ParallelToolUse *bool              `json:"parallel_tool_use,omitempty"`
}
|
||||
|
||||
// CanonicalUsage is the protocol-neutral token accounting. The pointer
// fields are optional counters not reported by every provider.
type CanonicalUsage struct {
	InputTokens         int  `json:"input_tokens"`
	OutputTokens        int  `json:"output_tokens"`
	CacheReadTokens     *int `json:"cache_read_tokens,omitempty"`
	CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
	ReasoningTokens     *int `json:"reasoning_tokens,omitempty"`
}
|
||||
|
||||
// CanonicalResponse is the protocol-neutral chat completion response.
// StopReason is nil when the provider did not report one.
type CanonicalResponse struct {
	ID         string         `json:"id"`
	Model      string         `json:"model"`
	Content    []ContentBlock `json:"content"`
	StopReason *StopReason    `json:"stop_reason,omitempty"`
	Usage      CanonicalUsage `json:"usage"`
}
|
||||
|
||||
// GetSystemString 获取系统消息字符串
|
||||
func (r *CanonicalRequest) GetSystemString() string {
|
||||
switch v := r.System.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []SystemBlock:
|
||||
var result string
|
||||
for i, b := range v {
|
||||
if i > 0 {
|
||||
result += "\n\n"
|
||||
}
|
||||
result += b.Text
|
||||
}
|
||||
return result
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSystemString 设置系统消息字符串
|
||||
func (r *CanonicalRequest) SetSystemString(s string) {
|
||||
if s == "" {
|
||||
r.System = nil
|
||||
} else {
|
||||
r.System = s
|
||||
}
|
||||
}
|
||||
338
backend/internal/conversion/engine.go
Normal file
338
backend/internal/conversion/engine.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPRequestSpec is a transport-agnostic description of an outbound HTTP
// request (full URL, method, headers and raw body) produced by the engine
// and executed by the provider client.
type HTTPRequestSpec struct {
	URL     string            `json:"url"`
	Method  string            `json:"method"`
	Headers map[string]string `json:"headers"`
	Body    []byte            `json:"body"`
}
|
||||
|
||||
// HTTPResponseSpec is a transport-agnostic description of an upstream HTTP
// response (status, headers and raw body) fed back through the engine.
type HTTPResponseSpec struct {
	StatusCode int               `json:"status_code"`
	Headers    map[string]string `json:"headers"`
	Body       []byte            `json:"body"`
}
|
||||
|
||||
// ConversionEngine is the facade of the protocol conversion layer: it looks
// up protocol adapters in the registry and runs canonical requests through
// the middleware chain.
type ConversionEngine struct {
	registry        AdapterRegistry
	middlewareChain *MiddlewareChain
}
|
||||
|
||||
// NewConversionEngine 创建转换引擎
|
||||
func NewConversionEngine(registry AdapterRegistry) *ConversionEngine {
|
||||
return &ConversionEngine{
|
||||
registry: registry,
|
||||
middlewareChain: NewMiddlewareChain(),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAdapter registers a protocol adapter with the engine's registry.
// It returns the registry's error (e.g. on duplicate registration).
func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error {
	return e.registry.Register(adapter)
}
|
||||
|
||||
// GetRegistry exposes the underlying adapter registry for external use.
func (e *ConversionEngine) GetRegistry() AdapterRegistry {
	return e.registry
}
|
||||
|
||||
// Use appends a conversion middleware to the engine's chain; middlewares run
// in registration order on every cross-protocol conversion.
func (e *ConversionEngine) Use(mw ConversionMiddleware) {
	e.middlewareChain.Use(mw)
}
|
||||
|
||||
// IsPassthrough 判断是否同协议透传
|
||||
func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool {
|
||||
if clientProtocol != providerProtocol {
|
||||
return false
|
||||
}
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return adapter.SupportsPassthrough()
|
||||
}
|
||||
|
||||
// ConvertHttpRequest rewrites an inbound client request into the request to
// send upstream.
//
// Passthrough (same protocol, adapter allows it): the body and path are kept
// verbatim; only the URL is prefixed with the provider base URL and the
// headers are rebuilt (auth headers come from the provider config — the
// client's original headers are intentionally dropped).
//
// Cross-protocol: the client adapter classifies the path into an interface
// type, the provider adapter builds the upstream URL and headers, and the
// body is converted through the canonical model via convertBody.
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
	nativePath := spec.URL

	if e.IsPassthrough(clientProtocol, providerProtocol) {
		providerAdapter, err := e.registry.Get(providerProtocol)
		if err != nil {
			return nil, err
		}
		return &HTTPRequestSpec{
			URL:     provider.BaseURL + nativePath,
			Method:  spec.Method,
			Headers: providerAdapter.BuildHeaders(provider),
			Body:    spec.Body,
		}, nil
	}

	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
	}

	// The client protocol decides what kind of endpoint was hit; the provider
	// protocol decides where that maps to upstream.
	interfaceType := clientAdapter.DetectInterfaceType(nativePath)
	providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
	providerHeaders := providerAdapter.BuildHeaders(provider)
	providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
	if err != nil {
		return nil, err
	}

	return &HTTPRequestSpec{
		URL:     provider.BaseURL + providerUrl,
		Method:  spec.Method,
		Headers: providerHeaders,
		Body:    providerBody,
	}, nil
}
|
||||
|
||||
// ConvertHttpResponse 转换 HTTP 响应
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
clientAdapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: convertedBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateStreamConverter builds the SSE stream converter for a client/provider
// protocol pair. Passthrough pairs get a zero-overhead forwarder; otherwise
// the provider's stream decoder is bridged to the client's stream encoder
// through the middleware chain.
//
// NOTE(review): the conversion context hard-codes InterfaceTypeChat —
// streaming is presumably chat-only today; revisit if other interfaces
// ever stream.
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
	if e.IsPassthrough(clientProtocol, providerProtocol) {
		return NewPassthroughStreamConverter(), nil
	}

	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, err
	}
	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, err
	}

	// Fresh per-stream context so each conversion is independently traceable.
	ctx := ConversionContext{
		ConversionID:  uuid.New().String(),
		InterfaceType: InterfaceTypeChat,
		Timestamp:     time.Now(),
	}

	return NewCanonicalStreamConverterWithMiddleware(
		providerAdapter.CreateStreamDecoder(),
		clientAdapter.CreateStreamEncoder(),
		e.middlewareChain,
		ctx,
		clientProtocol,
		providerProtocol,
	), nil
}
|
||||
|
||||
// convertBody 转换请求体
|
||||
func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatBody(clientAdapter, providerAdapter, provider, body)
|
||||
case InterfaceTypeModels, InterfaceTypeModelInfo:
|
||||
return body, nil
|
||||
case InterfaceTypeEmbeddings:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankBody(clientAdapter, providerAdapter, provider, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponseBody 转换响应体
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeModels:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertModelsResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeModelInfo:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModelInfo) || !providerAdapter.SupportsInterface(InterfaceTypeModelInfo) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeEmbeddings:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
||||
}
|
||||
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
canonicalReq, err = e.middlewareChain.Apply(canonicalReq, clientAdapter.ProtocolName(), providerAdapter.ProtocolName(), ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(err)
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// convertModelsResponseBody converts a model-list response between
// protocols. Extended-layer conversion is best-effort: on any decode/encode
// failure it logs a warning and returns the original body unchanged.
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
	models, err := providerAdapter.DecodeModelsResponse(body)
	if err != nil {
		zap.L().Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	encoded, err := clientAdapter.EncodeModelsResponse(models)
	if err != nil {
		zap.L().Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	return encoded, nil
}
|
||||
|
||||
// convertModelInfoResponseBody converts a single-model info response between
// protocols. Best-effort: on failure it logs a warning and returns the
// original body unchanged.
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
	info, err := providerAdapter.DecodeModelInfoResponse(body)
	if err != nil {
		zap.L().Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	encoded, err := clientAdapter.EncodeModelInfoResponse(info)
	if err != nil {
		zap.L().Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	return encoded, nil
}
|
||||
|
||||
// convertEmbeddingBody converts an embeddings request between protocols.
// A decode failure is logged and the original body forwarded unchanged;
// an encode failure is returned to the caller.
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	req, err := clientAdapter.DecodeEmbeddingRequest(body)
	if err != nil {
		zap.L().Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
		return body, nil
	}
	return providerAdapter.EncodeEmbeddingRequest(req, provider)
}
|
||||
|
||||
// convertEmbeddingResponseBody converts an embeddings response between
// protocols. A decode failure is logged and the original body forwarded
// unchanged; an encode failure is returned to the caller.
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
	resp, err := providerAdapter.DecodeEmbeddingResponse(body)
	if err != nil {
		zap.L().Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	return clientAdapter.EncodeEmbeddingResponse(resp)
}
|
||||
|
||||
// convertRerankBody converts a rerank request between protocols. A decode
// failure is logged and the original body forwarded unchanged; an encode
// failure is returned to the caller.
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	req, err := clientAdapter.DecodeRerankRequest(body)
	if err != nil {
		zap.L().Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
		return body, nil
	}
	return providerAdapter.EncodeRerankRequest(req, provider)
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
// DetectInterfaceType classifies a native request path using the client
// protocol's adapter. On an unknown protocol it returns
// InterfaceTypePassthrough along with the registry error.
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
	adapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return InterfaceTypePassthrough, err
	}
	return adapter.DetectInterfaceType(nativePath), nil
}
|
||||
|
||||
// EncodeError 使用客户端适配器编码错误
|
||||
func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) {
|
||||
adapter, adapterErr := e.registry.Get(clientProtocol)
|
||||
if adapterErr != nil {
|
||||
fallback := map[string]any{
|
||||
"error": map[string]string{
|
||||
"message": err.Error(),
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(fallback)
|
||||
return body, 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
}
|
||||
366
backend/internal/conversion/engine_test.go
Normal file
366
backend/internal/conversion/engine_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockProtocolAdapter is a configurable ProtocolAdapter test double. The Fn
// fields, when non-nil, override the default JSON round-trip behavior of the
// corresponding method.
type mockProtocolAdapter struct {
	protocolName    string
	passthrough     bool
	ifaceType       InterfaceType
	supportsIface   map[InterfaceType]bool
	decodeReqFn     func([]byte) (*canonical.CanonicalRequest, error)
	encodeReqFn     func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
	decodeRespFn    func([]byte) (*canonical.CanonicalResponse, error)
	encodeRespFn    func(*canonical.CanonicalResponse) ([]byte, error)
	streamDecoderFn func() StreamDecoder
	streamEncoderFn func() StreamEncoder
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
return &mockProtocolAdapter{
|
||||
protocolName: name,
|
||||
passthrough: passthrough,
|
||||
ifaceType: InterfaceTypeChat,
|
||||
supportsIface: map[InterfaceType]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// ProtocolName returns the configured mock protocol name.
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }

// ProtocolVersion returns a fixed version string.
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }

// SupportsPassthrough reports the configured passthrough flag.
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
|
||||
|
||||
// DetectInterfaceType ignores the path and returns the configured type.
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
	return m.ifaceType
}
|
||||
|
||||
// BuildUrl echoes the native path unchanged.
func (m *mockProtocolAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
	return nativePath
}
|
||||
|
||||
// BuildHeaders returns a bearer-auth header built from the provider API key.
func (m *mockProtocolAdapter) BuildHeaders(provider *TargetProvider) map[string]string {
	return map[string]string{"Authorization": "Bearer " + provider.APIKey}
}
|
||||
|
||||
func (m *mockProtocolAdapter) SupportsInterface(interfaceType InterfaceType) bool {
|
||||
if v, ok := m.supportsIface[interfaceType]; ok {
|
||||
return v
|
||||
}
|
||||
return interfaceType == InterfaceTypeChat
|
||||
}
|
||||
|
||||
// DecodeRequest delegates to decodeReqFn when set; otherwise it best-effort
// unmarshals the raw bytes into a canonical request (errors ignored so the
// mock never fails by default).
func (m *mockProtocolAdapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	if m.decodeReqFn != nil {
		return m.decodeReqFn(raw)
	}
	req := &canonical.CanonicalRequest{}
	_ = json.Unmarshal(raw, req)
	return req, nil
}
|
||||
|
||||
// EncodeRequest delegates to encodeReqFn when set; otherwise it marshals the
// canonical request as-is.
func (m *mockProtocolAdapter) EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error) {
	if m.encodeReqFn != nil {
		return m.encodeReqFn(req, provider)
	}
	return json.Marshal(req)
}
|
||||
|
||||
// DecodeResponse delegates to decodeRespFn when set; otherwise it
// best-effort unmarshals the raw bytes (errors ignored).
func (m *mockProtocolAdapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	if m.decodeRespFn != nil {
		return m.decodeRespFn(raw)
	}
	resp := &canonical.CanonicalResponse{}
	_ = json.Unmarshal(raw, resp)
	return resp, nil
}
|
||||
|
||||
// EncodeResponse delegates to encodeRespFn when set; otherwise it marshals
// the canonical response as-is.
func (m *mockProtocolAdapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	if m.encodeRespFn != nil {
		return m.encodeRespFn(resp)
	}
	return json.Marshal(resp)
}
|
||||
|
||||
// CreateStreamDecoder delegates to streamDecoderFn when set; otherwise it
// returns a no-op decoder.
func (m *mockProtocolAdapter) CreateStreamDecoder() StreamDecoder {
	if m.streamDecoderFn != nil {
		return m.streamDecoderFn()
	}
	return &noopStreamDecoder{}
}
|
||||
|
||||
// CreateStreamEncoder delegates to streamEncoderFn when set; otherwise it
// returns a no-op encoder.
func (m *mockProtocolAdapter) CreateStreamEncoder() StreamEncoder {
	if m.streamEncoderFn != nil {
		return m.streamEncoderFn()
	}
	return &noopStreamEncoder{}
}
|
||||
|
||||
// EncodeError returns a fixed error payload with HTTP 400.
func (m *mockProtocolAdapter) EncodeError(err *ConversionError) ([]byte, int) {
	return []byte(`{"error":"mock"}`), 400
}
|
||||
|
||||
// Extended-layer stubs: decoders return zero-value canonical structs and
// encoders marshal their input verbatim — just enough for the engine's
// dispatch paths to be exercised.

func (m *mockProtocolAdapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return &canonical.CanonicalModelList{}, nil
}

func (m *mockProtocolAdapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return json.Marshal(list)
}

func (m *mockProtocolAdapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return &canonical.CanonicalModelInfo{}, nil
}

func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return json.Marshal(info)
}

func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return &canonical.CanonicalEmbeddingRequest{}, nil
}

func (m *mockProtocolAdapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

func (m *mockProtocolAdapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return &canonical.CanonicalEmbeddingResponse{}, nil
}

func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return json.Marshal(resp)
}

func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return &canonical.CanonicalRerankRequest{}, nil
}

func (m *mockProtocolAdapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

func (m *mockProtocolAdapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return &canonical.CanonicalRerankResponse{}, nil
}

func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return json.Marshal(resp)
}
|
||||
|
||||
// noopStreamDecoder is a StreamDecoder that emits no events.
type noopStreamDecoder struct{}

func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent                       { return nil }
|
||||
|
||||
// noopStreamEncoder is a StreamEncoder that emits no output chunks.
type noopStreamEncoder struct{}

func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte                                           { return nil }
|
||||
|
||||
// ============ 测试用例 ============
|
||||
|
||||
// TestNewConversionEngine verifies the constructor wires in the registry.
func TestNewConversionEngine(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	assert.NotNil(t, engine)
	assert.Equal(t, registry, engine.GetRegistry())
}
|
||||
|
||||
// TestRegisterAdapter verifies RegisterAdapter adds the protocol to the registry.
func TestRegisterAdapter(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)

	adapter := newMockAdapter("test-proto", true)
	err := engine.RegisterAdapter(adapter)
	require.NoError(t, err)

	protocols := registry.ListProtocols()
	assert.Contains(t, protocols, "test-proto")
}
|
||||
|
||||
// TestIsPassthrough_SameProtocol: same protocol + passthrough-capable adapter => true.
func TestIsPassthrough_SameProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	adapter := newMockAdapter("openai", true)
	_ = engine.RegisterAdapter(adapter)

	assert.True(t, engine.IsPassthrough("openai", "openai"))
}
|
||||
|
||||
// TestIsPassthrough_DifferentProtocol: differing protocols are never passthrough.
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("openai", true))
	_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))

	assert.False(t, engine.IsPassthrough("openai", "anthropic"))
}
|
||||
|
||||
// TestIsPassthrough_NoPassthrough: an adapter that opts out disables passthrough
// even for same-protocol pairs.
func TestIsPassthrough_NoPassthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("custom", false))

	assert.False(t, engine.IsPassthrough("custom", "custom"))
}
|
||||
|
||||
// TestDetectInterfaceType verifies the engine delegates path classification
// to the client protocol's adapter.
func TestDetectInterfaceType(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	adapter := newMockAdapter("test", true)
	adapter.ifaceType = InterfaceTypeChat
	_ = engine.RegisterAdapter(adapter)

	ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test")
	require.NoError(t, err)
	assert.Equal(t, InterfaceTypeChat, ifaceType)
}
|
||||
|
||||
// TestDetectInterfaceType_NonExistentProtocol: unknown protocols surface an error.
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)

	_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
	assert.Error(t, err)
}
|
||||
|
||||
// TestConvertHttpRequest_Passthrough: passthrough keeps the body byte-identical
// and only prefixes the provider base URL.
func TestConvertHttpRequest_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("openai", true))

	provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat/completions",
		Method: "POST",
		Body:   []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
	require.NoError(t, err)
	assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
	assert.Equal(t, spec.Body, result.Body)
}
|
||||
|
||||
// TestConvertHttpRequest_CrossProtocol: a cross-protocol request is decoded by
// the client adapter and re-encoded by the provider adapter (here substituting
// the provider-configured model name).
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)

	clientAdapter := newMockAdapter("client-proto", false)
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return &canonical.CanonicalRequest{
			Model:    "test-model",
			Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		}, nil
	}
	_ = engine.RegisterAdapter(clientAdapter)

	providerAdapter := newMockAdapter("provider-proto", false)
	providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
		return json.Marshal(map[string]any{"model": p.ModelName})
	}
	_ = engine.RegisterAdapter(providerAdapter)

	provider := NewTargetProvider("https://example.com", "key", "my-model")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat",
		Method: "POST",
		Body:   []byte(`{"model":"test"}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "client-proto", "provider-proto", provider)
	require.NoError(t, err)
	assert.Contains(t, result.URL, "https://example.com")
	assert.NotNil(t, result.Body)
}
|
||||
|
||||
// TestConvertHttpResponse_Passthrough: passthrough responses come back unmodified.
func TestConvertHttpResponse_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("openai", true))

	spec := HTTPResponseSpec{
		StatusCode: 200,
		Body:       []byte(`{"id":"123"}`),
	}

	result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
	require.NoError(t, err)
	assert.Equal(t, 200, result.StatusCode)
	assert.Equal(t, spec.Body, result.Body)
}
|
||||
|
||||
// TestCreateStreamConverter_Passthrough: same-protocol pairs get the
// zero-overhead passthrough stream converter.
func TestCreateStreamConverter_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("openai", true))

	converter, err := engine.CreateStreamConverter("openai", "openai")
	require.NoError(t, err)
	_, ok := converter.(*PassthroughStreamConverter)
	assert.True(t, ok)
}
|
||||
|
||||
// TestCreateStreamConverter_Canonical: cross-protocol pairs get the canonical
// decoder→middleware→encoder stream converter.
func TestCreateStreamConverter_Canonical(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("client", false))
	_ = engine.RegisterAdapter(newMockAdapter("provider", false))

	converter, err := engine.CreateStreamConverter("client", "provider")
	require.NoError(t, err)
	_, ok := converter.(*CanonicalStreamConverter)
	assert.True(t, ok)
}
|
||||
|
||||
// TestEncodeError: a registered protocol renders the error through its
// adapter (mock adapter returns 400).
func TestEncodeError(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)
	_ = engine.RegisterAdapter(newMockAdapter("openai", true))

	convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
	body, statusCode, err := engine.EncodeError(convErr, "openai")
	require.NoError(t, err)
	assert.Equal(t, 400, statusCode)
	assert.NotNil(t, body)
}
|
||||
|
||||
// TestEncodeError_NonExistentProtocol: unknown protocols fall back to the
// generic 500 envelope that still carries the error message.
func TestEncodeError_NonExistentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry)

	convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
	body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
	require.NoError(t, err)
	assert.Equal(t, 500, statusCode)
	assert.Contains(t, string(body), "测试错误")
}
|
||||
|
||||
// TestRegistry_DuplicateRegistration: registering the same protocol twice fails.
func TestRegistry_DuplicateRegistration(t *testing.T) {
	registry := NewMemoryRegistry()
	adapter := newMockAdapter("openai", true)

	err := registry.Register(adapter)
	require.NoError(t, err)

	err = registry.Register(adapter)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "适配器已注册")
}
|
||||
|
||||
// TestRegistry_GetNonExistent: looking up an unregistered protocol errors.
func TestRegistry_GetNonExistent(t *testing.T) {
	registry := NewMemoryRegistry()

	_, err := registry.Get("nonexistent")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "未找到适配器")
}
|
||||
83
backend/internal/conversion/errors.go
Normal file
83
backend/internal/conversion/errors.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package conversion
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ErrorCode classifies protocol-conversion failures. Values are stable,
// machine-readable identifiers surfaced to clients in error payloads.
type ErrorCode string

// Conversion error codes.
const (
	ErrorCodeInvalidInput          ErrorCode = "INVALID_INPUT"
	ErrorCodeMissingRequiredField  ErrorCode = "MISSING_REQUIRED_FIELD"
	ErrorCodeIncompatibleFeature   ErrorCode = "INCOMPATIBLE_FEATURE"
	ErrorCodeFieldMappingFailure   ErrorCode = "FIELD_MAPPING_FAILURE"
	ErrorCodeToolCallParseError    ErrorCode = "TOOL_CALL_PARSE_ERROR"
	ErrorCodeJSONParseError        ErrorCode = "JSON_PARSE_ERROR"
	ErrorCodeStreamStateError      ErrorCode = "STREAM_STATE_ERROR"
	ErrorCodeUTF8DecodeError       ErrorCode = "UTF8_DECODE_ERROR"
	ErrorCodeProtocolConstraint    ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
	ErrorCodeEncodingFailure       ErrorCode = "ENCODING_FAILURE"
	ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
)

// ConversionError is the error type for protocol conversion failures. It is
// populated fluently via the With* builder methods and participates in
// errors.Is/errors.As chains through Unwrap.
type ConversionError struct {
	Code             ErrorCode      // machine-readable category
	Message          string         // human-readable description
	ClientProtocol   string         // protocol spoken by the client, if known
	ProviderProtocol string         // protocol of the upstream provider, if known
	InterfaceType    string         // API surface involved (chat, embeddings, ...)
	Details          map[string]any // extra structured context
	Cause            error          // wrapped underlying error; may be nil
}

// NewConversionError creates a ConversionError with an initialized Details map.
func NewConversionError(code ErrorCode, message string) *ConversionError {
	return &ConversionError{
		Code:    code,
		Message: message,
		Details: make(map[string]any),
	}
}

// WithClientProtocol records the client-side protocol and returns the receiver.
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
	e.ClientProtocol = protocol
	return e
}

// WithProviderProtocol records the provider-side protocol and returns the receiver.
func (e *ConversionError) WithProviderProtocol(protocol string) *ConversionError {
	e.ProviderProtocol = protocol
	return e
}

// WithInterfaceType records the interface type and returns the receiver.
func (e *ConversionError) WithInterfaceType(ifaceType string) *ConversionError {
	e.InterfaceType = ifaceType
	return e
}

// WithDetail attaches a key/value detail and returns the receiver. It lazily
// allocates Details so it is also safe on a ConversionError constructed as a
// literal or zero value, not only ones built with NewConversionError.
func (e *ConversionError) WithDetail(key string, value any) *ConversionError {
	if e.Details == nil {
		e.Details = make(map[string]any)
	}
	e.Details[key] = value
	return e
}

// WithCause records the wrapped underlying error and returns the receiver.
func (e *ConversionError) WithCause(cause error) *ConversionError {
	e.Cause = cause
	return e
}

// Error implements the error interface: "[CODE] message", with the cause
// appended when one is present.
func (e *ConversionError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Cause)
	}
	return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}

// Unwrap exposes the cause for errors.Is / errors.As.
func (e *ConversionError) Unwrap() error {
	return e.Cause
}
|
||||
45
backend/internal/conversion/errors_test.go
Normal file
45
backend/internal/conversion/errors_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestConversionError_Builder verifies the fluent With* builders populate the
// corresponding fields of ConversionError.
func TestConversionError_Builder(t *testing.T) {
	cause := errors.New("原始错误")
	err := NewConversionError(ErrorCodeInvalidInput, "输入无效").
		WithClientProtocol("openai").
		WithDetail("field", "model").
		WithCause(cause)

	assert.Equal(t, ErrorCodeInvalidInput, err.Code)
	assert.Equal(t, "openai", err.ClientProtocol)
	assert.Equal(t, "输入无效", err.Message)
	assert.Equal(t, "model", err.Details["field"])
	assert.Equal(t, cause, err.Cause)
}
|
||||
|
||||
// TestConversionError_Unwrap verifies Unwrap returns the wrapped cause.
func TestConversionError_Unwrap(t *testing.T) {
	cause := errors.New("根本原因")
	err := NewConversionError(ErrorCodeJSONParseError, "解析失败").WithCause(cause)

	unwrapped := err.Unwrap()
	assert.Equal(t, cause, unwrapped)
}
|
||||
|
||||
// TestConversionError_Error_WithCause verifies Error() renders the code,
// message, and the cause text when a cause is set.
func TestConversionError_Error_WithCause(t *testing.T) {
	err := NewConversionError(ErrorCodeInvalidInput, "输入无效").WithCause(errors.New("原因"))
	msg := err.Error()
	assert.Contains(t, msg, "INVALID_INPUT")
	assert.Contains(t, msg, "输入无效")
	assert.Contains(t, msg, "原因")
}
|
||||
|
||||
// TestConversionError_Error_WithoutCause verifies Error() renders only the
// code and message when no cause is set.
func TestConversionError_Error_WithoutCause(t *testing.T) {
	err := NewConversionError(ErrorCodeInvalidInput, "输入无效")
	msg := err.Error()
	assert.Contains(t, msg, "INVALID_INPUT")
	assert.Contains(t, msg, "输入无效")
}
|
||||
13
backend/internal/conversion/interface.go
Normal file
13
backend/internal/conversion/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package conversion
|
||||
|
||||
// InterfaceType identifies which API surface a request targets
// (chat completions, model listing/detail, embeddings, rerank, or raw passthrough).
type InterfaceType string

// Supported interface types. InterfaceTypePassthrough is used for paths an
// adapter does not recognize; such requests are forwarded without conversion.
const (
	InterfaceTypeChat        InterfaceType = "CHAT"
	InterfaceTypeModels      InterfaceType = "MODELS"
	InterfaceTypeModelInfo   InterfaceType = "MODEL_INFO"
	InterfaceTypeEmbeddings  InterfaceType = "EMBEDDINGS"
	InterfaceTypeRerank      InterfaceType = "RERANK"
	InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
)
|
||||
76
backend/internal/conversion/middleware.go
Normal file
76
backend/internal/conversion/middleware.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ConversionMiddleware intercepts canonical requests and stream events as they
// flow between the client and provider protocols, allowing inspection or
// rewriting before encoding. Returning an error aborts the chain.
type ConversionMiddleware interface {
	Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error)
	InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error)
}

// ConversionContext carries per-conversion bookkeeping shared by every
// middleware in a chain.
type ConversionContext struct {
	ConversionID  string         // unique identifier for this conversion (UUID)
	InterfaceType InterfaceType  // API surface being converted
	Timestamp     time.Time      // context creation time, UTC
	Metadata      map[string]any // free-form data middlewares may read and write
}

// NewConversionContext creates the context for one conversion with a fresh
// UUID and a UTC timestamp.
func NewConversionContext(ifaceType InterfaceType) *ConversionContext {
	return &ConversionContext{
		ConversionID:  uuid.New().String(),
		InterfaceType: ifaceType,
		Timestamp:     time.Now().UTC(),
		Metadata:      make(map[string]any),
	}
}
|
||||
|
||||
// MiddlewareChain 中间件链
|
||||
type MiddlewareChain struct {
|
||||
middlewares []ConversionMiddleware
|
||||
}
|
||||
|
||||
// NewMiddlewareChain 创建中间件链
|
||||
func NewMiddlewareChain() *MiddlewareChain {
|
||||
return &MiddlewareChain{
|
||||
middlewares: make([]ConversionMiddleware, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Use 添加中间件
|
||||
func (c *MiddlewareChain) Use(mw ConversionMiddleware) {
|
||||
c.middlewares = append(c.middlewares, mw)
|
||||
}
|
||||
|
||||
// Apply 对请求按顺序执行所有中间件
|
||||
func (c *MiddlewareChain) Apply(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
result := req
|
||||
for _, mw := range c.middlewares {
|
||||
var err error
|
||||
result, err = mw.Intercept(result, clientProtocol, providerProtocol, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ApplyStreamEvent 对流式事件按顺序执行所有中间件
|
||||
func (c *MiddlewareChain) ApplyStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
|
||||
result := event
|
||||
for _, mw := range c.middlewares {
|
||||
var err error
|
||||
result, err = mw.InterceptStreamEvent(result, clientProtocol, providerProtocol, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
85
backend/internal/conversion/middleware_test.go
Normal file
85
backend/internal/conversion/middleware_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// recordingMiddleware is a test double that appends its name to a shared
// records slice on every call; when err is set it aborts the chain.
type recordingMiddleware struct {
	name    string    // label appended to records on each call
	records *[]string // shared slice collecting the call order
	err     error     // when non-nil, returned to interrupt the chain
}

// Intercept records the call under the middleware's name, then either fails
// with the configured error or passes the request through unchanged.
func (m *recordingMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
	*m.records = append(*m.records, m.name)
	if m.err != nil {
		return nil, m.err
	}
	return req, nil
}

// InterceptStreamEvent records the call with a "stream:" prefix so tests can
// distinguish stream interception from request interception.
func (m *recordingMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
	*m.records = append(*m.records, "stream:"+m.name)
	if m.err != nil {
		return nil, m.err
	}
	return event, nil
}
|
||||
|
||||
// TestMiddlewareChain_Empty verifies an empty chain returns the request unchanged.
func TestMiddlewareChain_Empty(t *testing.T) {
	chain := NewMiddlewareChain()
	req := &canonical.CanonicalRequest{Model: "test"}
	ctx := NewConversionContext(InterfaceTypeChat)

	result, err := chain.Apply(req, "a", "b", ctx)
	assert.NoError(t, err)
	assert.Equal(t, "test", result.Model)
}
|
||||
|
||||
// TestMiddlewareChain_Order verifies middlewares run in registration order.
func TestMiddlewareChain_Order(t *testing.T) {
	var records []string
	chain := NewMiddlewareChain()
	chain.Use(&recordingMiddleware{name: "first", records: &records})
	chain.Use(&recordingMiddleware{name: "second", records: &records})
	chain.Use(&recordingMiddleware{name: "third", records: &records})

	req := &canonical.CanonicalRequest{Model: "test"}
	ctx := NewConversionContext(InterfaceTypeChat)
	_, err := chain.Apply(req, "a", "b", ctx)
	assert.NoError(t, err)
	assert.Equal(t, []string{"first", "second", "third"}, records)
}
|
||||
|
||||
// TestMiddlewareChain_ErrorInterrupt verifies a middleware error stops the
// chain: later middlewares never run and the error is propagated.
func TestMiddlewareChain_ErrorInterrupt(t *testing.T) {
	var records []string
	chain := NewMiddlewareChain()
	chain.Use(&recordingMiddleware{name: "first", records: &records})
	chain.Use(&recordingMiddleware{name: "second", records: &records, err: errors.New("中断")})
	chain.Use(&recordingMiddleware{name: "third", records: &records})

	req := &canonical.CanonicalRequest{Model: "test"}
	ctx := NewConversionContext(InterfaceTypeChat)
	_, err := chain.Apply(req, "a", "b", ctx)
	assert.Error(t, err)
	assert.Equal(t, "中断", err.Error())
	assert.Equal(t, []string{"first", "second"}, records)
}
|
||||
|
||||
// TestMiddlewareChain_ApplyStreamEvent verifies stream events go through the
// stream-side interception path (recorded with the "stream:" prefix).
func TestMiddlewareChain_ApplyStreamEvent(t *testing.T) {
	var records []string
	chain := NewMiddlewareChain()
	chain.Use(&recordingMiddleware{name: "mw1", records: &records})

	event := canonical.NewMessageStartEvent("id", "model")
	ctx := NewConversionContext(InterfaceTypeChat)
	result, err := chain.ApplyStreamEvent(&event, "a", "b", ctx)
	assert.NoError(t, err)
	assert.Equal(t, canonical.EventMessageStart, result.Type)
	assert.Equal(t, []string{"stream:mw1"}, records)
}
|
||||
211
backend/internal/conversion/openai/adapter.go
Normal file
211
backend/internal/conversion/openai/adapter.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter implements the protocol adapter for the OpenAI-compatible wire
// format. It holds no state, so a single instance may be shared freely.
type Adapter struct{}

// NewAdapter creates an OpenAI adapter.
func NewAdapter() *Adapter {
	return &Adapter{}
}

// modelInfoRegex matches "/v1/models/{id}": exactly one non-empty,
// slash-free segment after /v1/models/.
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)

// ProtocolName returns the identifier under which this adapter is registered.
func (a *Adapter) ProtocolName() string { return "openai" }

// ProtocolVersion returns the protocol version; empty for OpenAI.
func (a *Adapter) ProtocolVersion() string { return "" }

// SupportsPassthrough reports that same-protocol requests may be forwarded
// without re-encoding.
func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/chat/completions":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case modelInfoRegex.MatchString(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
case nativePath == "/v1/embeddings":
|
||||
return conversion.InterfaceTypeEmbeddings
|
||||
case nativePath == "/v1/rerank":
|
||||
return conversion.InterfaceTypeRerank
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/v1/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
return "/v1/rerank"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + provider.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
|
||||
headers["OpenAI-Organization"] = org
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo,
|
||||
conversion.InterfaceTypeEmbeddings,
|
||||
conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest parses a raw OpenAI chat request body into the canonical form.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest renders a canonical request as an OpenAI request body for the
// given target provider.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse parses a raw OpenAI chat response body into the canonical form.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse renders a canonical response as an OpenAI response body.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder returns a fresh decoder for OpenAI streaming output.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder returns a fresh encoder producing OpenAI streaming output.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}
|
||||
|
||||
// EncodeError renders a ConversionError as an OpenAI-style error response
// body plus an HTTP status code.
//
// NOTE(review): the status code is always 500 here, even when mapErrorCode
// classifies the error as invalid_request_error (which the OpenAI API itself
// reports with 400). The engine-level test using a mock adapter expects 400
// for INVALID_INPUT — confirm whether statusCode should depend on errType.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	errType := mapErrorCode(err.Code)
	statusCode := 500

	errMsg := ErrorResponse{
		Error: ErrorDetail{
			Message: err.Message,
			Type:    errType,
			Param:   nil,
			Code:    string(err.Code),
		},
	}
	// Marshal of this plain struct cannot realistically fail; error dropped.
	body, _ := json.Marshal(errMsg)
	return body, statusCode
}
|
||||
|
||||
// mapErrorCode 映射错误码到 OpenAI 错误类型
|
||||
func mapErrorCode(code conversion.ErrorCode) string {
|
||||
switch code {
|
||||
case conversion.ErrorCodeInvalidInput,
|
||||
conversion.ErrorCodeMissingRequiredField,
|
||||
conversion.ErrorCodeIncompatibleFeature,
|
||||
conversion.ErrorCodeToolCallParseError,
|
||||
conversion.ErrorCodeJSONParseError,
|
||||
conversion.ErrorCodeProtocolConstraint,
|
||||
conversion.ErrorCodeFieldMappingFailure:
|
||||
return "invalid_request_error"
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeModelsResponse parses a raw OpenAI model-list response into canonical form.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse renders a canonical model list as an OpenAI response body.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse parses a raw OpenAI model-detail response into canonical form.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse renders canonical model details as an OpenAI response body.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}

// DecodeEmbeddingRequest parses a raw OpenAI embeddings request into canonical form.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return decodeEmbeddingRequest(raw)
}

// EncodeEmbeddingRequest renders a canonical embeddings request as an OpenAI
// request body for the given provider.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeEmbeddingRequest(req, provider)
}

// DecodeEmbeddingResponse parses a raw OpenAI embeddings response into canonical form.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return decodeEmbeddingResponse(raw)
}

// EncodeEmbeddingResponse renders a canonical embeddings response as an OpenAI body.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return encodeEmbeddingResponse(resp)
}

// DecodeRerankRequest parses a raw OpenAI rerank request into canonical form.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return decodeRerankRequest(raw)
}

// EncodeRerankRequest renders a canonical rerank request as an OpenAI request
// body for the given provider.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRerankRequest(req, provider)
}

// DecodeRerankResponse parses a raw OpenAI rerank response into canonical form.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return decodeRerankResponse(raw)
}

// EncodeRerankResponse renders a canonical rerank response as an OpenAI body.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return encodeRerankResponse(resp)
}
|
||||
139
backend/internal/conversion/openai/adapter_test.go
Normal file
139
backend/internal/conversion/openai/adapter_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAdapter_ProtocolName pins the adapter's registration identifier.
func TestAdapter_ProtocolName(t *testing.T) {
	a := NewAdapter()
	assert.Equal(t, "openai", a.ProtocolName())
}
|
||||
|
||||
// TestAdapter_SupportsPassthrough verifies same-protocol passthrough is enabled.
func TestAdapter_SupportsPassthrough(t *testing.T) {
	a := NewAdapter()
	assert.True(t, a.SupportsPassthrough())
}
|
||||
|
||||
// TestAdapter_DetectInterfaceType table-tests path → interface-type routing,
// including the regex-matched model-detail path and the passthrough fallback.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name     string
		path     string
		expected conversion.InterfaceType
	}{
		{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
		{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
		{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
		{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
		{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
		{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.DetectInterfaceType(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_BuildUrl table-tests interface-type → endpoint mapping and the
// native-path fallback for passthrough.
func TestAdapter_BuildUrl(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		expected      string
	}{
		{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
		{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
		{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
		{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
		{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.BuildUrl(tt.nativePath, tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_BuildHeaders verifies the bearer and content-type headers, and
// that OpenAI-Organization is added only when configured.
func TestAdapter_BuildHeaders(t *testing.T) {
	a := NewAdapter()

	t.Run("基本头", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])
		_, hasOrg := headers["OpenAI-Organization"]
		assert.False(t, hasOrg)
	})

	t.Run("带组织", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		provider.AdapterConfig["organization"] = "org-abc"
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "org-abc", headers["OpenAI-Organization"])
	})
}
|
||||
|
||||
// TestAdapter_SupportsInterface table-tests interface support; passthrough is
// the only unsupported (non-converted) type.
func TestAdapter_SupportsInterface(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		interfaceType conversion.InterfaceType
		expected      bool
	}{
		{"聊天", conversion.InterfaceTypeChat, true},
		{"模型", conversion.InterfaceTypeModels, true},
		{"模型详情", conversion.InterfaceTypeModelInfo, true},
		{"嵌入", conversion.InterfaceTypeEmbeddings, true},
		{"重排序", conversion.InterfaceTypeRerank, true},
		{"透传", conversion.InterfaceTypePassthrough, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.SupportsInterface(tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_EncodeError_InvalidInput pins the current behavior: the error
// type maps to invalid_request_error but the status code is still 500.
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "参数无效", resp.Error.Message)
	assert.Equal(t, "invalid_request_error", resp.Error.Type)
}
|
||||
|
||||
// TestAdapter_EncodeError_ServerError verifies a non-client-fault code is
// rendered as server_error with status 500.
func TestAdapter_EncodeError_ServerError(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "server_error", resp.Error.Type)
	assert.Equal(t, "流状态错误", resp.Error.Message)
}
|
||||
669
backend/internal/conversion/openai/decoder.go
Normal file
669
backend/internal/conversion/openai/decoder.go
Normal file
@@ -0,0 +1,669 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest converts a raw OpenAI chat-completions request body into the
// protocol-neutral canonical request.
//
// Validation: model must be non-blank and at least one message is required;
// violations are reported as ConversionErrors (JSON_PARSE_ERROR or
// INVALID_INPUT). System/developer messages are lifted out of the message
// list into the canonical System field, the rest are converted one by one.
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
	var req ChatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 请求失败").WithCause(err)
	}

	if strings.TrimSpace(req.Model) == "" {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
	}
	if len(req.Messages) == 0 {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
	}

	// Deprecated-field compatibility; presumably rewrites legacy fields onto
	// their modern equivalents in req — confirm in decodeDeprecatedFields.
	decodeDeprecatedFields(&req)

	// Pull system/developer prompts out of the conversation.
	system, messages := decodeSystemPrompt(req.Messages)

	var canonicalMsgs []canonical.CanonicalMessage
	for _, msg := range messages {
		decoded, err := decodeMessage(msg)
		if err != nil {
			return nil, err
		}
		// One OpenAI message may expand to multiple canonical messages.
		canonicalMsgs = append(canonicalMsgs, decoded...)
	}

	tools := decodeTools(req.Tools)
	toolChoice := decodeToolChoice(req.ToolChoice)
	params := decodeParameters(&req)
	outputFormat := decodeOutputFormat(req.ResponseFormat)
	thinking := decodeThinking(req.ReasoningEffort)

	var parallelToolUse *bool
	if req.ParallelToolCalls != nil {
		parallelToolUse = req.ParallelToolCalls
	}

	return &canonical.CanonicalRequest{
		Model:           req.Model,
		System:          system,
		Messages:        canonicalMsgs,
		Tools:           tools,
		ToolChoice:      toolChoice,
		Parameters:      params,
		Thinking:        thinking,
		Stream:          req.Stream,
		UserID:          req.User,
		OutputFormat:    outputFormat,
		ParallelToolUse: parallelToolUse,
	}, nil
}
|
||||
|
||||
// decodeSystemPrompt 提取 system 和 developer 消息
|
||||
func decodeSystemPrompt(messages []Message) (any, []Message) {
|
||||
var systemParts []string
|
||||
var remaining []Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" || msg.Role == "developer" {
|
||||
text := extractText(msg.Content)
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
} else {
|
||||
remaining = append(remaining, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(systemParts) == 0 {
|
||||
return nil, remaining
|
||||
}
|
||||
return strings.Join(systemParts, "\n\n"), remaining
|
||||
}
|
||||
|
||||
// extractText flattens an OpenAI message content value into plain text.
// Strings pass through unchanged; content-part arrays contribute only their
// "text" parts, concatenated without a separator; nil yields the empty
// string; any other value falls back to its fmt %v rendering.
func extractText(content any) string {
	if content == nil {
		return ""
	}
	if s, ok := content.(string); ok {
		return s
	}
	parts, ok := content.([]any)
	if !ok {
		return fmt.Sprintf("%v", content)
	}

	var b strings.Builder
	for _, raw := range parts {
		part, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if kind, _ := part["type"].(string); kind != "text" {
			continue
		}
		if text, ok := part["text"].(string); ok {
			b.WriteString(text)
		}
	}
	return b.String()
}
|
||||
|
||||
// decodeMessage converts one OpenAI message into canonical messages.
//
// user      → a single user message with decoded content blocks.
// assistant → text/refusal/reasoning blocks followed by tool-use blocks
//             (tool_calls, then the deprecated function_call); an empty
//             assistant message gets one empty text block.
// tool      → a tool_result block keyed by tool_call_id.
// function  → (deprecated role) a tool_result block; note the name is used
//             as ToolUseID since legacy function messages carry no call ID.
// Any other role returns (nil, nil) and is silently dropped.
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
	switch msg.Role {
	case "user":
		blocks := decodeUserContent(msg.Content)
		return []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: blocks}}, nil

	case "assistant":
		var blocks []canonical.ContentBlock
		// Content may be a plain string or a content-part array.
		if msg.Content != nil {
			switch v := msg.Content.(type) {
			case string:
				if v != "" {
					blocks = append(blocks, canonical.NewTextBlock(v))
				}
			default:
				// Refusal parts are folded into plain text blocks.
				parts := decodeContentParts(msg.Content)
				for _, p := range parts {
					if p.Type == "text" {
						blocks = append(blocks, canonical.NewTextBlock(p.Text))
					} else if p.Type == "refusal" {
						blocks = append(blocks, canonical.NewTextBlock(p.Refusal))
					}
				}
			}
		}
		// Top-level refusal field (separate from refusal content parts).
		if msg.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(msg.Refusal))
		}
		// Non-standard reasoning_content field → canonical thinking block.
		if msg.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent))
		}
		// Tool calls: custom-tool input is wrapped as a JSON string; invalid
		// function-call argument JSON degrades to an empty object.
		for _, tc := range msg.ToolCalls {
			var input json.RawMessage
			if tc.Type == "custom" && tc.Custom != nil {
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
			} else if tc.Function != nil {
				parsed := json.RawMessage(tc.Function.Arguments)
				if !json.Valid(parsed) {
					parsed = json.RawMessage("{}")
				}
				input = parsed
			} else {
				input = json.RawMessage("{}")
			}
			name := ""
			if tc.Function != nil {
				name = tc.Function.Name
			} else if tc.Custom != nil {
				name = tc.Custom.Name
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
		// Deprecated function_call: no ID on the wire, so one is generated.
		if msg.FunctionCall != nil {
			input := json.RawMessage(msg.FunctionCall.Arguments)
			if !json.Valid(input) {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(generateID(), msg.FunctionCall.Name, input))
		}
		if len(blocks) == 0 {
			blocks = append(blocks, canonical.NewTextBlock(""))
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil

	case "tool":
		content := extractText(msg.Content)
		// NOTE(review): IsError is always false — OpenAI tool messages carry
		// no error flag, so failures cannot be represented here. Confirm
		// whether downstream encoders rely on this.
		isErr := false
		block := canonical.ContentBlock{
			Type:      "tool_result",
			ToolUseID: msg.ToolCallID,
			Content:   json.RawMessage(fmt.Sprintf("%q", content)),
			IsError:   &isErr,
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil

	case "function":
		content := extractText(msg.Content)
		isErr := false
		block := canonical.ContentBlock{
			Type:      "tool_result",
			ToolUseID: msg.Name,
			Content:   json.RawMessage(fmt.Sprintf("%q", content)),
			IsError:   &isErr,
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil
	}

	// Unknown roles are dropped without error.
	return nil, nil
}
|
||||
|
||||
// decodeUserContent 解码用户内容
|
||||
func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
case "input_audio":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "audio"})
|
||||
case "file":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "file"})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
}
|
||||
}
|
||||
|
||||
// contentPart is one decoded assistant content part: plain text or a refusal.
type contentPart struct {
	Type    string // "text" or "refusal"
	Text    string // set when Type == "text"
	Refusal string // set when Type == "refusal"
}

// decodeContentParts extracts the text and refusal parts from an assistant
// content-part array. Non-array content yields nil; non-map items and
// unrecognized part types are skipped.
func decodeContentParts(content any) []contentPart {
	items, ok := content.([]any)
	if !ok {
		return nil
	}

	var out []contentPart
	for _, raw := range items {
		part, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		kind, _ := part["type"].(string)
		switch kind {
		case "text":
			text, _ := part["text"].(string)
			out = append(out, contentPart{Type: "text", Text: text})
		case "refusal":
			refusal, _ := part["refusal"].(string)
			out = append(out, contentPart{Type: "refusal", Refusal: refusal})
		}
	}
	return out
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
var result []canonical.CanonicalTool
|
||||
for _, tool := range tools {
|
||||
if tool.Type == "function" && tool.Function != nil {
|
||||
result = append(result, canonical.CanonicalTool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
InputSchema: tool.Function.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice converts the OpenAI tool_choice field (string or
// object form) into the canonical tool choice. It returns nil when the
// field is absent or unrecognized, leaving the downstream default.
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
	if toolChoice == nil {
		return nil
	}
	switch v := toolChoice.(type) {
	case string:
		switch v {
		case "auto":
			return canonical.NewToolChoiceAuto()
		case "none":
			return canonical.NewToolChoiceNone()
		case "required":
			// OpenAI "required" maps to the canonical "any tool" choice.
			return canonical.NewToolChoiceAny()
		}
	case map[string]any:
		t, _ := v["type"].(string)
		switch t {
		case "function":
			if fn, ok := v["function"].(map[string]any); ok {
				name, _ := fn["name"].(string)
				return canonical.NewToolChoiceNamed(name)
			}
		case "custom":
			if custom, ok := v["custom"].(map[string]any); ok {
				name, _ := custom["name"].(string)
				return canonical.NewToolChoiceNamed(name)
			}
		case "allowed_tools":
			// NOTE(review): the allowed-tools list itself is dropped here;
			// only the mode ("required" vs default) survives — confirm this
			// lossy mapping is intentional.
			if at, ok := v["allowed_tools"].(map[string]any); ok {
				mode, _ := at["mode"].(string)
				if mode == "required" {
					return canonical.NewToolChoiceAny()
				}
				return canonical.NewToolChoiceAuto()
			}
			return canonical.NewToolChoiceAuto()
		}
	}
	return nil
}
|
||||
|
||||
// decodeParameters extracts the sampling parameters of an OpenAI chat
// request into the canonical parameter set. Pointer fields pass
// through as-is, so "unset" is preserved.
func decodeParameters(req *ChatCompletionRequest) canonical.RequestParameters {
	params := canonical.RequestParameters{
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		FrequencyPenalty: req.FrequencyPenalty,
		PresencePenalty:  req.PresencePenalty,
	}
	// Prefer the newer max_completion_tokens over the deprecated
	// max_tokens when both are present.
	if req.MaxCompletionTokens != nil {
		params.MaxTokens = req.MaxCompletionTokens
	} else if req.MaxTokens != nil {
		params.MaxTokens = req.MaxTokens
	}
	if req.Stop != nil {
		params.StopSequences = normalizeStop(req.Stop)
	}
	return params
}
|
||||
|
||||
// normalizeStop coerces the OpenAI "stop" field — which may arrive as
// a string, a JSON-decoded []any, or an already-typed []string — into
// a flat list of non-empty stop sequences. It returns nil when nothing
// usable remains (empty string, empty list, or an unsupported type).
func normalizeStop(stop any) []string {
	if s, ok := stop.(string); ok {
		if s == "" {
			return nil
		}
		return []string{s}
	}
	if list, ok := stop.([]string); ok {
		return list
	}
	items, ok := stop.([]any)
	if !ok {
		return nil
	}
	var out []string
	for _, item := range items {
		if s, ok := item.(string); ok && s != "" {
			out = append(out, s)
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
|
||||
|
||||
// decodeOutputFormat converts the OpenAI response_format field into a
// canonical output format. Plain "text" and unknown types map to nil,
// meaning no output constraint.
func decodeOutputFormat(format *ResponseFormat) *canonical.OutputFormat {
	if format == nil {
		return nil
	}
	switch format.Type {
	case "json_object":
		return &canonical.OutputFormat{Type: "json_object"}
	case "json_schema":
		if format.JSONSchema != nil {
			return &canonical.OutputFormat{
				Type:   "json_schema",
				Name:   format.JSONSchema.Name,
				Schema: format.JSONSchema.Schema,
				Strict: format.JSONSchema.Strict,
			}
		}
		// json_schema declared without a schema payload: keep the type only.
		return &canonical.OutputFormat{Type: "json_schema"}
	case "text":
		return nil
	}
	return nil
}
|
||||
|
||||
// decodeThinking 解码推理配置
|
||||
func decodeThinking(reasoningEffort string) *canonical.ThinkingConfig {
|
||||
if reasoningEffort == "" {
|
||||
return nil
|
||||
}
|
||||
if reasoningEffort == "none" {
|
||||
return &canonical.ThinkingConfig{Type: "disabled"}
|
||||
}
|
||||
effort := reasoningEffort
|
||||
if effort == "minimal" {
|
||||
effort = "low"
|
||||
}
|
||||
return &canonical.ThinkingConfig{Type: "enabled", Effort: effort}
|
||||
}
|
||||
|
||||
// decodeDeprecatedFields upgrades the deprecated OpenAI "functions" /
// "function_call" request fields to their modern "tools" /
// "tool_choice" equivalents, mutating req in place. The modern fields
// always win when both are present.
func decodeDeprecatedFields(req *ChatCompletionRequest) {
	if len(req.Tools) == 0 && len(req.Functions) > 0 {
		req.Tools = make([]Tool, len(req.Functions))
		for i, f := range req.Functions {
			req.Tools[i] = Tool{
				Type: "function",
				Function: &FunctionDef{
					Name:        f.Name,
					Description: f.Description,
					Parameters:  f.Parameters,
				},
			}
		}
	}
	if req.ToolChoice == nil && req.FunctionCall != nil {
		switch v := req.FunctionCall.(type) {
		case string:
			switch v {
			case "none":
				req.ToolChoice = "none"
			case "auto":
				req.ToolChoice = "auto"
			}
		case map[string]any:
			// {"name": "..."} becomes the named-function tool_choice form.
			if name, ok := v["name"].(string); ok {
				req.ToolChoice = map[string]any{
					"type":     "function",
					"function": map[string]any{"name": name},
				}
			}
		}
	}
}
|
||||
|
||||
// decodeResponse decodes a non-streaming OpenAI chat completion
// response into the canonical response. Only the first choice is
// considered. Content blocks are emitted in the order: text, refusal
// (as text), thinking, then tool_use blocks; an empty result always
// gets one empty text block so downstream encoders have content.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp ChatCompletionResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 响应失败").WithCause(err)
	}

	// No choices at all: return a minimal, well-formed canonical response.
	if len(resp.Choices) == 0 {
		return &canonical.CanonicalResponse{
			ID:      resp.ID,
			Model:   resp.Model,
			Content: []canonical.ContentBlock{canonical.NewTextBlock("")},
			Usage:   canonical.CanonicalUsage{},
		}, nil
	}

	choice := resp.Choices[0]
	var blocks []canonical.ContentBlock

	if choice.Message != nil {
		if choice.Message.Content != nil {
			text := extractText(choice.Message.Content)
			if text != "" {
				blocks = append(blocks, canonical.NewTextBlock(text))
			}
		}
		// Refusals are surfaced as plain text blocks in the canonical model.
		if choice.Message.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(choice.Message.Refusal))
		}
		if choice.Message.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(choice.Message.ReasoningContent))
		}
		for _, tc := range choice.Message.ToolCalls {
			var input json.RawMessage
			name := ""
			if tc.Type == "custom" && tc.Custom != nil {
				// Custom tool input is free text; wrap it as a JSON string.
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
				name = tc.Custom.Name
			} else if tc.Function != nil {
				input = json.RawMessage(tc.Function.Arguments)
				// Providers sometimes emit empty/invalid arguments; fall back
				// to an empty object so the canonical input stays valid JSON.
				if !json.Valid(input) {
					input = json.RawMessage("{}")
				}
				name = tc.Function.Name
			} else {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
	}

	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	var stopReason *canonical.StopReason
	if choice.FinishReason != nil {
		sr := mapFinishReason(*choice.FinishReason)
		stopReason = &sr
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: stopReason,
		Usage:      decodeUsage(resp.Usage),
	}, nil
}
|
||||
|
||||
// mapFinishReason 映射结束原因
|
||||
func mapFinishReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "stop":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "length":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_calls":
|
||||
return canonical.StopReasonToolUse
|
||||
case "function_call":
|
||||
return canonical.StopReasonToolUse
|
||||
case "content_filter":
|
||||
return canonical.StopReasonContentFilter
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeUsage maps OpenAI token usage onto the canonical usage record.
// Cached-prompt and reasoning token details are carried over only when
// present and positive; a nil usage yields the zero value.
func decodeUsage(usage *Usage) canonical.CanonicalUsage {
	if usage == nil {
		return canonical.CanonicalUsage{}
	}
	result := canonical.CanonicalUsage{
		InputTokens:  usage.PromptTokens,
		OutputTokens: usage.CompletionTokens,
	}
	if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
		val := usage.PromptTokensDetails.CachedTokens
		result.CacheReadTokens = &val
	}
	if usage.CompletionTokensDetails != nil && usage.CompletionTokensDetails.ReasoningTokens > 0 {
		val := usage.CompletionTokensDetails.ReasoningTokens
		result.ReasoningTokens = &val
	}
	return result
}
|
||||
|
||||
// decodeModelsResponse parses an OpenAI model-list response into the
// canonical model list. The model ID doubles as the display name since
// OpenAI has no separate name field.
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
	var resp ModelsResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	models := make([]canonical.CanonicalModel, len(resp.Data))
	for i, m := range resp.Data {
		models[i] = canonical.CanonicalModel{
			ID:      m.ID,
			Name:    m.ID,
			Created: m.Created,
			OwnedBy: m.OwnedBy,
		}
	}
	return &canonical.CanonicalModelList{Models: models}, nil
}
|
||||
|
||||
// decodeModelInfoResponse parses a single-model detail response into
// the canonical model info; as with the list form, the ID is reused as
// the display name.
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
	var resp ModelInfoResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	return &canonical.CanonicalModelInfo{
		ID:      resp.ID,
		Name:    resp.ID,
		Created: resp.Created,
		OwnedBy: resp.OwnedBy,
	}, nil
}
|
||||
|
||||
// decodeEmbeddingRequest parses an OpenAI embeddings request into its
// canonical counterpart; fields map one-to-one.
func decodeEmbeddingRequest(body []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	var req EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalEmbeddingRequest{
		Model:          req.Model,
		Input:          req.Input,
		EncodingFormat: req.EncodingFormat,
		Dimensions:     req.Dimensions,
	}, nil
}
|
||||
|
||||
// decodeEmbeddingResponse parses an OpenAI embeddings response into
// the canonical form, copying per-item index/vector pairs and the
// token usage.
func decodeEmbeddingResponse(body []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	var resp EmbeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	data := make([]canonical.EmbeddingData, len(resp.Data))
	for i, d := range resp.Data {
		data[i] = canonical.EmbeddingData{Index: d.Index, Embedding: d.Embedding}
	}
	return &canonical.CanonicalEmbeddingResponse{
		Data:  data,
		Model: resp.Model,
		Usage: canonical.EmbeddingUsage{
			PromptTokens: resp.Usage.PromptTokens,
			TotalTokens:  resp.Usage.TotalTokens,
		},
	}, nil
}
|
||||
|
||||
// decodeRerankRequest parses a rerank request into its canonical
// counterpart; fields map one-to-one.
func decodeRerankRequest(body []byte) (*canonical.CanonicalRerankRequest, error) {
	var req RerankRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalRerankRequest{
		Model:           req.Model,
		Query:           req.Query,
		Documents:       req.Documents,
		TopN:            req.TopN,
		ReturnDocuments: req.ReturnDocuments,
	}, nil
}
|
||||
|
||||
// decodeRerankResponse parses a rerank response into the canonical
// form, preserving per-result index, score, and optional document.
func decodeRerankResponse(body []byte) (*canonical.CanonicalRerankResponse, error) {
	var resp RerankResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	results := make([]canonical.RerankResult, len(resp.Results))
	for i, r := range resp.Results {
		results[i] = canonical.RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
			Document:       r.Document,
		}
	}
	return &canonical.CanonicalRerankResponse{Results: results, Model: resp.Model}, nil
}
|
||||
|
||||
// generateID produces a tool-call ID of the form "call_<n>" from a
// process-wide atomic counter; IDs are unique within a single process
// run, not globally.
func generateID() string {
	return fmt.Sprintf("call_%d", generateCounter())
}
|
||||
|
||||
// idCounter backs generateCounter. The typed atomic (Go 1.19+)
// replaces the bare atomic.AddInt64-on-package-var pattern, making
// accidental non-atomic access impossible.
var idCounter atomic.Int64

// generateCounter returns the next value of a process-wide,
// goroutine-safe, monotonically increasing counter; the first call
// yields 1.
func generateCounter() int64 {
	return idCounter.Add(1)
}
|
||||
411
backend/internal/conversion/openai/decoder_test.go
Normal file
411
backend/internal/conversion/openai/decoder_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeRequest_BasicChat verifies that model, a single user
// message, and temperature survive decoding into the canonical request.
func TestDecodeRequest_BasicChat(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "你好"}
		],
		"temperature": 0.7
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.7, *req.Parameters.Temperature)
}
|
||||
|
||||
// TestDecodeRequest_SystemAndDeveloper verifies that system and
// developer messages are merged (blank-line separated) into the
// canonical system prompt and removed from the message list.
func TestDecodeRequest_SystemAndDeveloper(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "system", "content": "你是助手"},
			{"role": "developer", "content": "额外指令"},
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手\n\n额外指令", req.System)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
}
|
||||
|
||||
// TestDecodeRequest_ToolCalls verifies that an assistant message's
// tool_calls become canonical tool_use blocks with ID and name intact.
func TestDecodeRequest_ToolCalls(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{
					"id": "call_123",
					"type": "function",
					"function": {"name": "get_weather", "arguments": "{\"city\":\"北京\"}"}
				}]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Messages, 2)
	assistantMsg := req.Messages[1]
	assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
	found := false
	for _, b := range assistantMsg.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_123", b.ID)
			assert.Equal(t, "get_weather", b.Name)
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
// TestDecodeRequest_ToolMessage verifies that a role:"tool" message is
// decoded to a canonical tool message whose result carries the
// originating tool_call_id.
func TestDecodeRequest_ToolMessage(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]
			},
			{
				"role": "tool",
				"tool_call_id": "call_1",
				"content": "晴天 25°C"
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	toolMsg := req.Messages[2]
	assert.Equal(t, canonical.RoleTool, toolMsg.Role)
	assert.Equal(t, "call_1", toolMsg.Content[0].ToolUseID)
}
|
||||
|
||||
// TestDecodeRequest_MissingModel verifies a request without a model is
// rejected with an INVALID_INPUT conversion error.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
// TestDecodeRequest_MissingMessages verifies a request without any
// messages is rejected with an INVALID_INPUT conversion error.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model":"gpt-4"}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
// TestDecodeRequest_DeprecatedFunctions verifies the deprecated
// "functions" field is upgraded to canonical tools.
func TestDecodeRequest_DeprecatedFunctions(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "test"}],
		"functions": [{
			"name": "get_weather",
			"description": "获取天气",
			"parameters": {"type":"object","properties":{"city":{"type":"string"}}}
		}]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Tools, 1)
	assert.Equal(t, "get_weather", req.Tools[0].Name)
}
|
||||
|
||||
// TestDecodeResponse_Basic verifies ID, model, text content, stop
// reason, and token usage all survive decoding a plain response.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-123",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": "你好"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", resp.ID)
	assert.Equal(t, "gpt-4", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.NotNil(t, resp.StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
	assert.Equal(t, 5, resp.Usage.OutputTokens)
}
|
||||
|
||||
// TestDecodeResponse_ToolCalls verifies tool_calls in a response become
// tool_use blocks and finish_reason tool_calls maps to tool_use.
func TestDecodeResponse_ToolCalls(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-456",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"tool_calls": [{
					"id": "call_abc",
					"type": "function",
					"function": {"name": "search", "arguments": "{\"q\":\"test\"}"}
				}]
			},
			"finish_reason": "tool_calls"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_abc", b.ID)
			assert.Equal(t, "search", b.Name)
		}
	}
	assert.True(t, found)
	assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
}
|
||||
|
||||
// TestDecodeResponse_Thinking verifies reasoning_content is decoded
// into a thinking block placed after the text block.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-789",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"content": "回答",
				"reasoning_content": "思考过程"
			},
			"finish_reason": "stop"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "回答", resp.Content[0].Text)
	assert.Equal(t, "thinking", resp.Content[1].Type)
	assert.Equal(t, "思考过程", resp.Content[1].Thinking)
}
|
||||
|
||||
// TestDecodeModelsResponse verifies the model-list decoder preserves
// IDs, ordering, and creation timestamps.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"object": "list",
		"data": [
			{"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "openai"},
			{"id": "gpt-3.5-turbo", "object": "model", "created": 1700000001, "owned_by": "openai"}
		]
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "gpt-4", list.Models[0].ID)
	assert.Equal(t, "gpt-3.5-turbo", list.Models[1].ID)
	assert.Equal(t, int64(1700000000), list.Models[0].Created)
}
|
||||
|
||||
// TestDecodeRequest_InvalidJSON verifies malformed JSON surfaces as a
// JSON_PARSE_ERROR conversion error.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}
|
||||
|
||||
// TestDecodeRequest_Parameters verifies every sampling parameter
// (including max_completion_tokens and stop) round-trips into the
// canonical parameter set.
func TestDecodeRequest_Parameters(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"temperature": 0.5,
		"max_completion_tokens": 2048,
		"top_p": 0.9,
		"frequency_penalty": 0.1,
		"presence_penalty": 0.2,
		"stop": ["STOP"]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.5, *req.Parameters.Temperature)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 2048, *req.Parameters.MaxTokens)
	assert.NotNil(t, req.Parameters.TopP)
	assert.Equal(t, 0.9, *req.Parameters.TopP)
	assert.NotNil(t, req.Parameters.FrequencyPenalty)
	assert.Equal(t, 0.1, *req.Parameters.FrequencyPenalty)
	assert.NotNil(t, req.Parameters.PresencePenalty)
	assert.Equal(t, 0.2, *req.Parameters.PresencePenalty)
	assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}
|
||||
|
||||
// TestDecodeRequest_ToolChoice table-tests the tool_choice mappings:
// auto, none, required→any, and the named-function object form.
func TestDecodeRequest_ToolChoice(t *testing.T) {
	tests := []struct {
		name     string
		jsonBody string
		want     *canonical.ToolChoice
	}{
		{
			name:     "auto",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
			want:     canonical.NewToolChoiceAuto(),
		},
		{
			name:     "none",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
			want:     canonical.NewToolChoiceNone(),
		},
		{
			name:     "required",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"required"}`,
			want:     canonical.NewToolChoiceAny(),
		},
		{
			name:     "named",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"function","function":{"name":"x"}}}`,
			want:     canonical.NewToolChoiceNamed("x"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := decodeRequest([]byte(tt.jsonBody))
			require.NoError(t, err)
			require.NotNil(t, req.ToolChoice)
			assert.Equal(t, tt.want.Type, req.ToolChoice.Type)
			assert.Equal(t, tt.want.Name, req.ToolChoice.Name)
		})
	}
}
|
||||
|
||||
// TestDecodeRequest_OutputFormat_JSONSchema verifies the json_schema
// response_format carries its name, schema, and strict flag through.
func TestDecodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {
			"type": "json_schema",
			"json_schema": {
				"name": "my_schema",
				"schema": {"type":"object","properties":{"name":{"type":"string"}}},
				"strict": true
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.Equal(t, "my_schema", req.OutputFormat.Name)
	assert.NotNil(t, req.OutputFormat.Schema)
	require.NotNil(t, req.OutputFormat.Strict)
	assert.True(t, *req.OutputFormat.Strict)
}
|
||||
|
||||
// TestDecodeRequest_OutputFormat_JSON verifies the json_object
// response_format decodes to the canonical json_object format.
func TestDecodeRequest_OutputFormat_JSON(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {"type": "json_object"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_object", req.OutputFormat.Type)
}
|
||||
|
||||
// TestDecodeResponse_StopReasons table-tests the finish_reason →
// canonical stop reason mapping.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name         string
		finishReason string
		want         canonical.StopReason
	}{
		{"stop→end_turn", "stop", canonical.StopReasonEndTurn},
		{"length→max_tokens", "length", canonical.StopReasonMaxTokens},
		{"tool_calls→tool_use", "tool_calls", canonical.StopReasonToolUse},
		{"content_filter→content_filter", "content_filter", canonical.StopReasonContentFilter},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "resp-1",
				"model": "gpt-4",
				"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "%s"}],
				"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
			}`, tt.finishReason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}
|
||||
|
||||
// TestDecodeResponse_Usage verifies token counts and the cached-token
// detail are decoded into the canonical usage record.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
		"usage": {
			"prompt_tokens": 100,
			"completion_tokens": 50,
			"total_tokens": 150,
			"prompt_tokens_details": {"cached_tokens": 80}
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 80, *resp.Usage.CacheReadTokens)
}
|
||||
|
||||
// TestDecodeResponse_Refusal verifies a refusal message is surfaced as
// a text content block.
func TestDecodeResponse_Refusal(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": null, "refusal": "我拒绝回答"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Text == "我拒绝回答" {
			found = true
		}
	}
	assert.True(t, found)
}
|
||||
532
backend/internal/conversion/openai/encoder.go
Normal file
532
backend/internal/conversion/openai/encoder.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// encodeRequest encodes a canonical request as an OpenAI chat
// completion request body, substituting the target provider's model
// name. Optional canonical fields are emitted only when set.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// System prompt plus conversation messages.
	messages := encodeSystemAndMessages(req)
	result["messages"] = messages

	// Sampling parameters.
	encodeParametersInto(req, result)

	// Tool definitions.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tools[i] = map[string]any{
				"type": "function",
				"function": map[string]any{
					"name":        t.Name,
					"description": t.Description,
					"parameters":  t.InputSchema,
				},
			}
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common optional fields.
	if req.UserID != "" {
		result["user"] = req.UserID
	}
	if req.OutputFormat != nil {
		result["response_format"] = encodeOutputFormat(req.OutputFormat)
	}
	if req.ParallelToolUse != nil {
		result["parallel_tool_calls"] = *req.ParallelToolUse
	}
	if req.Thinking != nil {
		switch req.Thinking.Type {
		case "disabled":
			result["reasoning_effort"] = "none"
		default:
			// Enabled thinking: pass the effort through, defaulting to medium.
			if req.Thinking.Effort != "" {
				result["reasoning_effort"] = req.Thinking.Effort
			} else {
				result["reasoning_effort"] = "medium"
			}
		}
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 请求失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// encodeSystemAndMessages builds the OpenAI messages array: an
// optional leading system message derived from the canonical system
// prompt, then the encoded conversation, with consecutive same-role
// messages merged.
func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any {
	var messages []map[string]any

	// System prompt: either a plain string or a list of system blocks
	// joined with blank lines.
	switch v := req.System.(type) {
	case string:
		if v != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": v,
			})
		}
	case []canonical.SystemBlock:
		var parts []string
		for _, b := range v {
			parts = append(parts, b.Text)
		}
		text := joinStrings(parts, "\n\n")
		if text != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": text,
			})
		}
	}

	// Conversation messages.
	for _, msg := range req.Messages {
		encoded := encodeMessage(msg)
		messages = append(messages, encoded...)
	}

	// Collapse consecutive messages sharing a role, as OpenAI expects
	// alternating turns.
	return mergeConsecutiveRoles(messages)
}
|
||||
|
||||
// encodeMessage encodes a single canonical message as zero or more
// OpenAI messages. User content keeps its block structure; assistant
// text blocks are concatenated and tool_use blocks become tool_calls;
// tool messages emit one role:"tool" entry. Unknown roles yield nil.
func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
	switch msg.Role {
	case canonical.RoleUser:
		return []map[string]any{{
			"role":    "user",
			"content": encodeUserContent(msg.Content),
		}}
	case canonical.RoleAssistant:
		m := map[string]any{"role": "assistant"}
		var textParts []string
		var toolUses []canonical.ContentBlock

		// Split the content into text and tool_use; other block types
		// (e.g. thinking) are not sent back to OpenAI.
		for _, b := range msg.Content {
			switch b.Type {
			case "text":
				textParts = append(textParts, b.Text)
			case "tool_use":
				toolUses = append(toolUses, b)
			}
		}

		if len(toolUses) > 0 {
			// OpenAI uses null content when a tool-calling turn has no text.
			if len(textParts) > 0 {
				m["content"] = joinStrings(textParts, "")
			} else {
				m["content"] = nil
			}
			tcs := make([]map[string]any, len(toolUses))
			for i, tu := range toolUses {
				tcs[i] = map[string]any{
					"id":   tu.ID,
					"type": "function",
					"function": map[string]any{
						"name":      tu.Name,
						"arguments": string(tu.Input),
					},
				}
			}
			m["tool_calls"] = tcs
		} else if len(textParts) > 0 {
			m["content"] = joinStrings(textParts, "")
		} else {
			m["content"] = ""
		}
		return []map[string]any{m}

	case canonical.RoleTool:
		for _, b := range msg.Content {
			if b.Type == "tool_result" {
				// Tool-result content may be a JSON-encoded string; unwrap
				// it, otherwise pass the raw JSON through verbatim.
				var contentStr string
				if b.Content != nil {
					var s string
					if json.Unmarshal(b.Content, &s) == nil {
						contentStr = s
					} else {
						contentStr = string(b.Content)
					}
				}
				// NOTE(review): only the first tool_result block is emitted;
				// any additional blocks in this message are dropped — confirm
				// upstream never produces more than one.
				return []map[string]any{{
					"role":         "tool",
					"tool_call_id": b.ToolUseID,
					"content":      contentStr,
				}}
			}
		}
	}
	return nil
}
|
||||
|
||||
// encodeUserContent 编码用户内容
|
||||
func encodeUserContent(blocks []canonical.ContentBlock) any {
|
||||
if len(blocks) == 1 && blocks[0].Type == "text" {
|
||||
return blocks[0].Text
|
||||
}
|
||||
parts := make([]map[string]any, 0, len(blocks))
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
parts = append(parts, map[string]any{"type": "text", "text": b.Text})
|
||||
case "image":
|
||||
parts = append(parts, map[string]any{"type": "image_url"})
|
||||
case "audio":
|
||||
parts = append(parts, map[string]any{"type": "input_audio"})
|
||||
case "file":
|
||||
parts = append(parts, map[string]any{"type": "file"})
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return "auto"
|
||||
case "none":
|
||||
return "none"
|
||||
case "any":
|
||||
return "required"
|
||||
case "tool":
|
||||
return map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": choice.Name,
|
||||
},
|
||||
}
|
||||
}
|
||||
return "auto"
|
||||
}
|
||||
|
||||
// encodeParametersInto 编码参数到结果 map
|
||||
func encodeParametersInto(req *canonical.CanonicalRequest, result map[string]any) {
|
||||
if req.Parameters.MaxTokens != nil {
|
||||
result["max_completion_tokens"] = *req.Parameters.MaxTokens
|
||||
}
|
||||
if req.Parameters.Temperature != nil {
|
||||
result["temperature"] = *req.Parameters.Temperature
|
||||
}
|
||||
if req.Parameters.TopP != nil {
|
||||
result["top_p"] = *req.Parameters.TopP
|
||||
}
|
||||
if req.Parameters.FrequencyPenalty != nil {
|
||||
result["frequency_penalty"] = *req.Parameters.FrequencyPenalty
|
||||
}
|
||||
if req.Parameters.PresencePenalty != nil {
|
||||
result["presence_penalty"] = *req.Parameters.PresencePenalty
|
||||
}
|
||||
if len(req.Parameters.StopSequences) > 0 {
|
||||
result["stop"] = req.Parameters.StopSequences
|
||||
}
|
||||
}
|
||||
|
||||
// encodeOutputFormat 编码输出格式
|
||||
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
|
||||
switch format.Type {
|
||||
case "json_object":
|
||||
return map[string]any{"type": "json_object"}
|
||||
case "json_schema":
|
||||
m := map[string]any{"type": "json_schema"}
|
||||
schema := map[string]any{
|
||||
"name": format.Name,
|
||||
}
|
||||
if format.Schema != nil {
|
||||
schema["schema"] = format.Schema
|
||||
}
|
||||
if format.Strict != nil {
|
||||
schema["strict"] = *format.Strict
|
||||
}
|
||||
m["json_schema"] = schema
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse encodes a canonical (non-streaming) response as an OpenAI
// chat.completion JSON body. Text and thinking blocks are concatenated in
// order; tool_use blocks become the "tool_calls" array. Exactly one choice
// is emitted (index 0).
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	var textParts []string
	var thinkingParts []string
	var toolUses []canonical.ContentBlock

	// Bucket the canonical content blocks by type, preserving order within
	// each bucket.
	for _, b := range resp.Content {
		switch b.Type {
		case "text":
			textParts = append(textParts, b.Text)
		case "thinking":
			thinkingParts = append(thinkingParts, b.Thinking)
		case "tool_use":
			toolUses = append(toolUses, b)
		}
	}

	message := map[string]any{"role": "assistant"}
	if len(toolUses) > 0 {
		// When tool calls are present, OpenAI allows "content" to be null;
		// keep any accompanying text, otherwise emit explicit null.
		if len(textParts) > 0 {
			message["content"] = joinStrings(textParts, "")
		} else {
			message["content"] = nil
		}
		tcs := make([]map[string]any, len(toolUses))
		for i, tu := range toolUses {
			tcs[i] = map[string]any{
				"id":   tu.ID,
				"type": "function",
				"function": map[string]any{
					"name":      tu.Name,
					"arguments": string(tu.Input),
				},
			}
		}
		message["tool_calls"] = tcs
	} else if len(textParts) > 0 {
		message["content"] = joinStrings(textParts, "")
	} else {
		// No content at all: OpenAI clients expect a string, not null.
		message["content"] = ""
	}

	// Thinking output is exposed via the non-standard "reasoning_content"
	// field used by several OpenAI-compatible providers.
	if len(thinkingParts) > 0 {
		message["reasoning_content"] = joinStrings(thinkingParts, "")
	}

	var finishReason *string
	if resp.StopReason != nil {
		fr := mapCanonicalToFinishReason(*resp.StopReason)
		finishReason = &fr
	}

	result := map[string]any{
		"id":      resp.ID,
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   resp.Model,
		"choices": []map[string]any{{
			"index":         0,
			"message":       message,
			"finish_reason": finishReason,
		}},
		"usage": encodeUsage(resp.Usage),
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 响应失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// mapCanonicalToFinishReason 映射 Canonical 停止原因到 OpenAI finish_reason
|
||||
func mapCanonicalToFinishReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn:
|
||||
return "stop"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "length"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_calls"
|
||||
case canonical.StopReasonContentFilter:
|
||||
return "content_filter"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeUsage 编码用量
|
||||
func encodeUsage(usage canonical.CanonicalUsage) map[string]any {
|
||||
result := map[string]any{
|
||||
"prompt_tokens": usage.InputTokens,
|
||||
"completion_tokens": usage.OutputTokens,
|
||||
"total_tokens": usage.InputTokens + usage.OutputTokens,
|
||||
}
|
||||
if usage.CacheReadTokens != nil && *usage.CacheReadTokens > 0 {
|
||||
result["prompt_tokens_details"] = map[string]any{
|
||||
"cached_tokens": *usage.CacheReadTokens,
|
||||
}
|
||||
}
|
||||
if usage.ReasoningTokens != nil && *usage.ReasoningTokens > 0 {
|
||||
result["completion_tokens_details"] = map[string]any{
|
||||
"reasoning_tokens": *usage.ReasoningTokens,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
created := int64(0)
|
||||
if m.Created != 0 {
|
||||
created = m.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if m.OwnedBy != "" {
|
||||
ownedBy = m.OwnedBy
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
created := int64(0)
|
||||
if info.Created != 0 {
|
||||
created = info.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if info.OwnedBy != "" {
|
||||
ownedBy = info.OwnedBy
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeEmbeddingRequest 编码嵌入请求
|
||||
func encodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"input": req.Input,
|
||||
}
|
||||
if req.EncodingFormat != "" {
|
||||
result["encoding_format"] = req.EncodingFormat
|
||||
}
|
||||
if req.Dimensions != nil {
|
||||
result["dimensions"] = *req.Dimensions
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeEmbeddingResponse 编码嵌入响应
|
||||
func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||||
data := make([]map[string]any, len(resp.Data))
|
||||
for i, d := range resp.Data {
|
||||
data[i] = map[string]any{
|
||||
"index": d.Index,
|
||||
"embedding": d.Embedding,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeRerankRequest 编码重排序请求
|
||||
func encodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"query": req.Query,
|
||||
"documents": req.Documents,
|
||||
}
|
||||
if req.TopN != nil {
|
||||
result["top_n"] = *req.TopN
|
||||
}
|
||||
if req.ReturnDocuments != nil {
|
||||
result["return_documents"] = *req.ReturnDocuments
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeRerankResponse 编码重排序响应
|
||||
func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
results := make([]map[string]any, len(resp.Results))
|
||||
for i, r := range resp.Results {
|
||||
m := map[string]any{
|
||||
"index": r.Index,
|
||||
"relevance_score": r.RelevanceScore,
|
||||
}
|
||||
if r.Document != nil {
|
||||
m["document"] = *r.Document
|
||||
}
|
||||
results[i] = m
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"results": results,
|
||||
"model": resp.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// joinStrings concatenates parts with sep between consecutive elements,
// mirroring strings.Join without pulling in the strings package.
func joinStrings(parts []string, sep string) string {
	if len(parts) == 0 {
		return ""
	}
	joined := parts[0]
	for _, p := range parts[1:] {
		joined += sep + p
	}
	return joined
}
|
||||
|
||||
// mergeConsecutiveRoles merges runs of consecutive messages sharing a role
// by concatenating their content (string with string, or part-list with
// part-list), normalizing transcripts for providers that require
// alternating roles.
//
// Bug fix: previously a same-role message whose content type differed from
// its predecessor's (e.g. a string followed by a part list) was dropped
// entirely — the merge silently failed and `continue` discarded the
// message. Such messages are now kept as separate entries.
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
	if len(messages) <= 1 {
		return messages
	}
	var result []map[string]any
	for _, msg := range messages {
		if len(result) > 0 && result[len(result)-1]["role"] == msg["role"] {
			prev := result[len(result)-1]
			merged := false
			switch pv := prev["content"].(type) {
			case string:
				if cv, ok := msg["content"].(string); ok {
					prev["content"] = pv + cv
					merged = true
				}
			case []any:
				if cv, ok := msg["content"].([]any); ok {
					prev["content"] = append(pv, cv...)
					merged = true
				}
			}
			if merged {
				continue
			}
			// Content shapes are incompatible; fall through and keep the
			// message rather than silently discarding its content.
		}
		result = append(result, msg)
	}
	return result
}
|
||||
355
backend/internal/conversion/openai/encoder_test.go
Normal file
355
backend/internal/conversion/openai/encoder_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEncodeRequest_Basic verifies that the encoder substitutes the
// provider's model name, preserves the stream flag, and emits the message
// list.
func TestEncodeRequest_Basic(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Stream:   true,
	}
	provider := conversion.NewTargetProvider("", "key", "my-model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "my-model", result["model"])
	assert.Equal(t, true, result["stream"])

	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assert.Len(t, msgs, 1)
}
|
||||
|
||||
// TestEncodeRequest_SystemInjection verifies that the canonical System
// prompt is injected as a leading "system" message.
func TestEncodeRequest_SystemInjection(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		System:   "你是助手",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)
	assert.Len(t, msgs, 2)
	firstMsg := msgs[0].(map[string]any)
	assert.Equal(t, "system", firstMsg["role"])
	assert.Equal(t, "你是助手", firstMsg["content"])
}
|
||||
|
||||
// TestEncodeRequest_ToolCalls verifies that an assistant tool_use block is
// encoded as an OpenAI "tool_calls" entry carrying the original call ID.
func TestEncodeRequest_ToolCalls(t *testing.T) {
	input := json.RawMessage(`{"city":"北京"}`)
	req := &canonical.CanonicalRequest{
		Model: "gpt-4",
		Messages: []canonical.CanonicalMessage{
			{
				Role: canonical.RoleAssistant,
				Content: []canonical.ContentBlock{
					canonical.NewToolUseBlock("call_1", "get_weather", input),
				},
			},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)
	assistantMsg := msgs[0].(map[string]any)
	toolCalls, ok := assistantMsg["tool_calls"].([]any)
	require.True(t, ok)
	assert.Len(t, toolCalls, 1)
	tc := toolCalls[0].(map[string]any)
	assert.Equal(t, "call_1", tc["id"])
}
|
||||
|
||||
// TestEncodeRequest_Thinking verifies that an enabled thinking config is
// encoded as OpenAI's "reasoning_effort" field.
func TestEncodeRequest_Thinking(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "high", result["reasoning_effort"])
}
|
||||
|
||||
// TestEncodeResponse_Basic verifies the non-streaming response envelope:
// id, object type, message content, and finish_reason mapping for end_turn.
func TestEncodeResponse_Basic(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "resp-1",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "resp-1", result["id"])
	assert.Equal(t, "chat.completion", result["object"])

	choices := result["choices"].([]any)
	choice := choices[0].(map[string]any)
	msg := choice["message"].(map[string]any)
	assert.Equal(t, "你好", msg["content"])
	assert.Equal(t, "stop", choice["finish_reason"])
}
|
||||
|
||||
// TestEncodeResponse_ToolUse verifies a tool_use block in the response is
// surfaced as a "tool_calls" entry on the assistant message.
func TestEncodeResponse_ToolUse(t *testing.T) {
	sr := canonical.StopReasonToolUse
	input := json.RawMessage(`{"q":"test"}`)
	resp := &canonical.CanonicalResponse{
		ID:         "resp-2",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices := result["choices"].([]any)
	msg := choices[0].(map[string]any)["message"].(map[string]any)
	tcs, ok := msg["tool_calls"].([]any)
	require.True(t, ok)
	assert.Len(t, tcs, 1)
}
|
||||
|
||||
// TestEncodeModelsResponse verifies the model list is wrapped in the OpenAI
// {"object":"list","data":[...]} envelope with one entry per model.
func TestEncodeModelsResponse(t *testing.T) {
	list := &canonical.CanonicalModelList{
		Models: []canonical.CanonicalModel{
			{ID: "gpt-4", Created: 1700000000, OwnedBy: "openai"},
			{ID: "gpt-3.5-turbo", Created: 1700000001, OwnedBy: "openai"},
		},
	}

	body, err := encodeModelsResponse(list)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "list", result["object"])
	data := result["data"].([]any)
	assert.Len(t, data, 2)
}
|
||||
|
||||
// TestMergeConsecutiveRoles verifies consecutive same-role string contents
// are concatenated into a single message per run.
func TestMergeConsecutiveRoles(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "A"},
		{"role": "user", "content": "B"},
		{"role": "assistant", "content": "C"},
		{"role": "assistant", "content": "D"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 2)
	assert.Equal(t, "AB", result[0]["content"])
	assert.Equal(t, "CD", result[1]["content"])
}
|
||||
|
||||
// TestMergeConsecutiveRoles_NotOverwriting verifies merging appends rather
// than replaces the earlier message's content (multi-byte text included).
func TestMergeConsecutiveRoles_NotOverwriting(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "你好"},
		{"role": "user", "content": "世界"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 1)
	assert.Equal(t, "你好世界", result[0]["content"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_Auto verifies canonical "auto" encodes to
// the OpenAI string "auto".
func TestEncodeRequest_ToolChoice_Auto(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAuto(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "auto", result["tool_choice"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_None verifies canonical "none" encodes to
// the OpenAI string "none".
func TestEncodeRequest_ToolChoice_None(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNone(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "none", result["tool_choice"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_Required verifies canonical "any" maps to
// OpenAI's "required".
func TestEncodeRequest_ToolChoice_Required(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAny(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "required", result["tool_choice"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_Named verifies a named tool choice becomes
// the OpenAI {"type":"function","function":{"name":...}} selector object.
func TestEncodeRequest_ToolChoice_Named(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNamed("my_func"),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	tc, ok := result["tool_choice"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "function", tc["type"])
	fn, ok := tc["function"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_func", fn["name"])
}
|
||||
|
||||
// TestEncodeRequest_OutputFormat_JSONSchema verifies a json_schema output
// format produces the nested "response_format" object with name and schema.
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type:   "json_schema",
			Name:   "my_schema",
			Schema: schema,
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	rf, ok := result["response_format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", rf["type"])
	js, ok := rf["json_schema"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_schema", js["name"])
	assert.NotNil(t, js["schema"])
}
|
||||
|
||||
// TestEncodeRequest_OutputFormat_Text verifies that with no output format
// set, "response_format" is omitted from the request entirely.
func TestEncodeRequest_OutputFormat_Text(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	_, hasResponseFormat := result["response_format"]
	assert.False(t, hasResponseFormat)
}
|
||||
|
||||
// TestEncodeResponse_Thinking verifies thinking blocks are surfaced as the
// non-standard "reasoning_content" field alongside normal text content.
func TestEncodeResponse_Thinking(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:    "resp-thinking",
		Model: "gpt-4",
		Content: []canonical.ContentBlock{
			canonical.NewTextBlock("回答"),
			canonical.NewThinkingBlock("思考过程"),
		},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices := result["choices"].([]any)
	msg := choices[0].(map[string]any)["message"].(map[string]any)
	assert.Equal(t, "回答", msg["content"])
	assert.Equal(t, "思考过程", msg["reasoning_content"])
}
|
||||
|
||||
// TestEncodeRequest_Parameters verifies optional sampling parameters are
// copied through, with MaxTokens landing in "max_completion_tokens" and
// stop sequences in "stop".
func TestEncodeRequest_Parameters(t *testing.T) {
	temp := 0.5
	maxTokens := 2048
	topP := 0.9
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Parameters: canonical.RequestParameters{
			Temperature:   &temp,
			MaxTokens:     &maxTokens,
			TopP:          &topP,
			StopSequences: []string{"STOP", "END"},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, temp, result["temperature"])
	assert.Equal(t, float64(maxTokens), result["max_completion_tokens"])
	assert.Equal(t, topP, result["top_p"])
	stop, ok := result["stop"].([]any)
	require.True(t, ok)
	assert.Len(t, stop, 2)
	assert.Equal(t, "STOP", stop[0])
	assert.Equal(t, "END", stop[1])
}
|
||||
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package openai
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"unicode/utf8"

	"nex/backend/internal/conversion/canonical"
)
|
||||
|
||||
// StreamDecoder incrementally converts an OpenAI SSE chat-completion
// stream into canonical stream events. It is stateful: create one decoder
// per upstream response and feed it chunks in arrival order.
type StreamDecoder struct {
	messageStarted     bool           // true once MessageStart has been emitted
	openBlocks         map[int]string // canonical block index -> kind ("text", "thinking", or "tool_use_<n>")
	textBlockIndex     int            // index of the open text block; -1 when none
	thinkingBlockIndex int            // index of the open thinking block; -1 when none
	refusalBlockIndex  int            // index of the text block carrying refusal deltas; -1 when none
	toolCallIDMap      map[int]string // OpenAI tool_calls array index -> call ID
	toolCallNameMap    map[int]string // OpenAI tool_calls array index -> function name
	nextToolCallIdx    int            // NOTE(review): not written anywhere in this file's visible code — possibly unused
	utf8Remainder      []byte         // trailing bytes of a UTF-8 rune split across chunks
	accumulatedUsage   *canonical.CanonicalUsage // usage from the final usage-only chunk, if the upstream sent one
}
|
||||
|
||||
// NewStreamDecoder 创建 OpenAI 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
openBlocks: make(map[int]string),
|
||||
toolCallIDMap: make(map[int]string),
|
||||
toolCallNameMap: make(map[int]string),
|
||||
textBlockIndex: -1,
|
||||
thinkingBlockIndex: -1,
|
||||
refusalBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk splits a raw SSE chunk into "data: " lines and converts
// each JSON payload into canonical stream events. On "[DONE]" it closes
// all open content blocks and stops processing the chunk.
//
// NOTE(review): lines are split per chunk with no cross-chunk line buffer,
// so this assumes network chunk boundaries align with SSE line boundaries
// — verify against a server that flushes mid-line.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	// Re-attach any trailing UTF-8 bytes stashed by processDataChunk.
	// NOTE(review): the remainder was cut from a JSON payload but is
	// prepended at the raw SSE level — confirm split-rune recovery works
	// end to end.
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		data = append(d.utf8Remainder, rawChunk...)
		d.utf8Remainder = nil
	}

	var events []canonical.CanonicalStreamEvent

	// Walk the SSE frame line by line; only "data: " lines carry payloads.
	lines := strings.Split(string(data), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")

		if payload == "[DONE]" {
			// End of stream: close every still-open content block.
			events = append(events, d.flushOpenBlocks()...)
			return events
		}

		chunkEvents := d.processDataChunk([]byte(payload))
		events = append(events, chunkEvents...)
	}

	return events
}
|
||||
|
||||
// Flush 刷新解码器状态
|
||||
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDataChunk decodes one SSE JSON payload (an OpenAI
// chat.completion.chunk) and translates it into canonical stream events,
// updating the decoder's open-block bookkeeping as it goes.
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
	// If the payload ends mid-rune, trim the incomplete trailing bytes and
	// stash them for the next chunk (re-attached in ProcessChunk).
	if !utf8.Valid(data) {
		validEnd := len(data)
		for !utf8.Valid(data[:validEnd]) {
			validEnd--
		}
		d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
		data = data[:validEnd]
	}

	var chunk StreamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		// Malformed payloads are skipped rather than aborting the stream.
		return nil
	}

	var events []canonical.CanonicalStreamEvent

	// First chunk of the stream: emit MessageStart exactly once.
	if !d.messageStarted {
		events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
		d.messageStarted = true
	}

	for _, choice := range chunk.Choices {
		if choice.Delta == nil {
			continue
		}
		delta := choice.Delta

		// Text content: open a text block lazily on the first non-empty
		// delta, then forward each delta as-is.
		if delta.Content != nil {
			text := ""
			switch v := delta.Content.(type) {
			case string:
				text = v
			default:
				// Non-string content is stringified best-effort.
				text = fmt.Sprintf("%v", v)
			}
			if text != "" {
				if _, ok := d.openBlocks[d.textBlockIndex]; !ok || d.textBlockIndex < 0 {
					d.textBlockIndex = d.allocateBlockIndex()
					d.openBlocks[d.textBlockIndex] = "text"
					events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex,
						canonical.StreamContentBlock{Type: "text", Text: ""}))
				}
				events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
			}
		}

		// reasoning_content (non-standard field used by several
		// OpenAI-compatible providers) maps to a thinking block.
		if delta.ReasoningContent != "" {
			if _, ok := d.openBlocks[d.thinkingBlockIndex]; !ok || d.thinkingBlockIndex < 0 {
				d.thinkingBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.thinkingBlockIndex] = "thinking"
				events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex,
					canonical.StreamContentBlock{Type: "thinking", Thinking: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
		}

		// Refusal text is represented as an ordinary text block on the
		// canonical side.
		if delta.Refusal != "" {
			if _, ok := d.openBlocks[d.refusalBlockIndex]; !ok || d.refusalBlockIndex < 0 {
				d.refusalBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.refusalBlockIndex] = "text"
				events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex,
					canonical.StreamContentBlock{Type: "text", Text: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
		}

		// Tool calls: a chunk carrying an ID opens a new tool_use block;
		// argument fragments are forwarded as input-JSON deltas.
		if len(delta.ToolCalls) > 0 {
			for _, tc := range delta.ToolCalls {
				tcIdx := 0
				if tc.Index != nil {
					tcIdx = *tc.Index
				}

				if tc.ID != "" {
					// New tool call: record its ID/name and open a block
					// keyed by the OpenAI tool-call index.
					d.toolCallIDMap[tcIdx] = tc.ID
					if tc.Function != nil {
						d.toolCallNameMap[tcIdx] = tc.Function.Name
					}
					blockIdx := d.allocateBlockIndex()
					d.openBlocks[blockIdx] = fmt.Sprintf("tool_use_%d", tcIdx)
					name := d.toolCallNameMap[tcIdx]
					events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: name}))
				}

				// Resolve which canonical block this tool call's argument
				// fragments belong to.
				blockIdx := d.findToolUseBlockIndex(tcIdx)
				if tc.Function != nil && tc.Function.Arguments != "" {
					events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
				}
			}
		}

		// finish_reason terminates the message: close all open blocks and
		// emit MessageDelta (with mapped stop reason) plus MessageStop.
		if choice.FinishReason != nil && *choice.FinishReason != "" {
			events = append(events, d.flushOpenBlocks()...)
			sr := mapFinishReason(*choice.FinishReason)
			events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
			events = append(events, canonical.NewMessageStopEvent())
		}
	}

	// A trailing usage-only chunk (empty choices) carries the final token
	// counts; surface them via a MessageDelta event.
	if len(chunk.Choices) == 0 && chunk.Usage != nil {
		usage := decodeUsage(chunk.Usage)
		d.accumulatedUsage = &usage
		events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
	}

	return events
}
|
||||
|
||||
// allocateBlockIndex 分配 block 索引
|
||||
func (d *StreamDecoder) allocateBlockIndex() int {
|
||||
maxIdx := -1
|
||||
for k := range d.openBlocks {
|
||||
if k > maxIdx {
|
||||
maxIdx = k
|
||||
}
|
||||
}
|
||||
return maxIdx + 1
|
||||
}
|
||||
|
||||
// findToolUseBlockIndex 查找 tool use block 索引
|
||||
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
|
||||
key := fmt.Sprintf("tool_use_%d", tcIdx)
|
||||
for blockIdx, typ := range d.openBlocks {
|
||||
if typ == key {
|
||||
return blockIdx
|
||||
}
|
||||
}
|
||||
return d.allocateBlockIndex()
|
||||
}
|
||||
|
||||
// flushOpenBlocks 关闭所有 open blocks
|
||||
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
|
||||
var events []canonical.CanonicalStreamEvent
|
||||
for idx := range d.openBlocks {
|
||||
events = append(events, canonical.NewContentBlockStopEvent(idx))
|
||||
}
|
||||
d.openBlocks = make(map[int]string)
|
||||
return events
|
||||
}
|
||||
355
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
355
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeSSEData wraps a JSON payload in an SSE "data:" line terminated by a
// blank line, matching what an OpenAI-compatible server emits per event.
func makeSSEData(payload string) []byte {
	var frame []byte
	frame = append(frame, "data: "...)
	frame = append(frame, payload...)
	frame = append(frame, "\n\n"...)
	return frame
}
|
||||
|
||||
func TestStreamDecoder_BasicText(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "你好"},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
foundStart := false
|
||||
foundDelta := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventMessageStart {
|
||||
foundStart = true
|
||||
assert.Equal(t, "chatcmpl-1", e.Message.ID)
|
||||
}
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
|
||||
foundDelta = true
|
||||
assert.Equal(t, "text_delta", e.Delta.Type)
|
||||
assert.Equal(t, "你好", e.Delta.Text)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStart)
|
||||
assert.True(t, foundDelta)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ToolCalls(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
idx := 0
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx,
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"北京\"}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
found := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
|
||||
found = true
|
||||
assert.Equal(t, "call_1", e.ContentBlock.ID)
|
||||
assert.Equal(t, "get_weather", e.ContentBlock.Name)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_Thinking(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"reasoning_content": "思考中",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
found := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "thinking_delta" {
|
||||
found = true
|
||||
assert.Equal(t, "思考中", e.Delta.Thinking)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_FinishReason(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
foundStop := false
|
||||
foundMsgStop := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventMessageDelta && e.StopReason != nil {
|
||||
foundStop = true
|
||||
assert.Equal(t, canonical.StopReasonEndTurn, *e.StopReason)
|
||||
}
|
||||
if e.Type == canonical.EventMessageStop {
|
||||
foundMsgStop = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStop)
|
||||
assert.True(t, foundMsgStop)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_DoneSignal(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// 先发送一个文本 chunk
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "hi"},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := append(makeSSEData(string(data)), []byte("data: [DONE]\n\n")...)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
// 应该包含 block stop 事件([DONE] 触发 flushOpenBlocks)
|
||||
foundBlockStop := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockStop {
|
||||
foundBlockStop = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundBlockStop)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"refusal": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
events := d.ProcessChunk(raw)
|
||||
_ = events
|
||||
}
|
||||
|
||||
// 检查只创建了一个 text block(refusal 复用同一个 block)
|
||||
assert.Contains(t, d.openBlocks, d.refusalBlockIndex)
|
||||
}
|
||||
|
||||
// makeChunkSSE marshals a chunk object and wraps it in an SSE "data:" frame.
func makeChunkSSE(chunk map[string]any) []byte {
	payload, _ := json.Marshal(chunk)
	frame := append([]byte("data: "), payload...)
	return append(frame, '\n', '\n')
}
|
||||
|
||||
func TestStreamDecoder_UsageChunk(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-usage",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
}
|
||||
raw := makeChunkSSE(chunk)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
found := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventMessageDelta {
|
||||
found = true
|
||||
require.NotNil(t, e.Usage)
|
||||
assert.Equal(t, 100, e.Usage.InputTokens)
|
||||
assert.Equal(t, 50, e.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
idx0 := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx0,
|
||||
"id": "call_a",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "func_a",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
idx1 := 1
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx1,
|
||||
"id": "call_b",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "func_b",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events1 := d.ProcessChunk(makeChunkSSE(chunk1))
|
||||
require.NotEmpty(t, events1)
|
||||
|
||||
events2 := d.ProcessChunk(makeChunkSSE(chunk2))
|
||||
require.NotEmpty(t, events2)
|
||||
|
||||
blockIndices := map[int]bool{}
|
||||
for _, e := range append(events1, events2...) {
|
||||
if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
|
||||
require.NotNil(t, e.Index)
|
||||
blockIndices[*e.Index] = true
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, len(blockIndices), "两个 tool call 应分配不同的 block 索引")
|
||||
}
|
||||
|
||||
func TestStreamDecoder_Flush(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
result := d.Flush()
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "你好"},
|
||||
},
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "世界"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
raw := append(makeChunkSSE(chunk1), makeChunkSSE(chunk2)...)
|
||||
events := d.ProcessChunk(raw)
|
||||
|
||||
deltas := []string{}
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "text_delta" {
|
||||
deltas = append(deltas, e.Delta.Text)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, []string{"你好", "世界"}, deltas)
|
||||
}
|
||||
217
backend/internal/conversion/openai/stream_encoder.go
Normal file
217
backend/internal/conversion/openai/stream_encoder.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder OpenAI 流式编码器
|
||||
type StreamEncoder struct {
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
func NewStreamEncoder() *StreamEncoder {
|
||||
return &StreamEncoder{
|
||||
toolCallIndexMap: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeEvent 编码 Canonical 事件为 SSE chunk
|
||||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
switch event.Type {
|
||||
case canonical.EventMessageStart:
|
||||
return e.encodeMessageStart(event)
|
||||
case canonical.EventContentBlockStart:
|
||||
return e.bufferBlockStart(event)
|
||||
case canonical.EventContentBlockDelta:
|
||||
return e.encodeContentBlockDelta(event)
|
||||
case canonical.EventContentBlockStop:
|
||||
return nil
|
||||
case canonical.EventMessageDelta:
|
||||
return e.encodeMessageDelta(event)
|
||||
case canonical.EventMessageStop:
|
||||
return [][]byte{[]byte("data: [DONE]\n\n")}
|
||||
case canonical.EventPing, canonical.EventError:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区
|
||||
func (e *StreamEncoder) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
id := ""
|
||||
model := ""
|
||||
if event.Message != nil {
|
||||
id = event.Message.ID
|
||||
model = event.Message.Model
|
||||
}
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"role": "assistant"},
|
||||
}},
|
||||
}
|
||||
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// bufferBlockStart 缓冲 block start 事件
|
||||
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
e.bufferedStart = &event
|
||||
if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" {
|
||||
idx := e.nextToolCallIndex
|
||||
e.nextToolCallIndex++
|
||||
if event.ContentBlock.ID != "" {
|
||||
e.toolCallIndexMap[event.ContentBlock.ID] = idx
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
return e.encodeTextDelta(event)
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
return e.encodeInputJSONDelta(event)
|
||||
case canonical.DeltaTypeThinking:
|
||||
return e.encodeThinkingDelta(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTextDelta 编码文本增量
|
||||
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"content": event.Delta.Text,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeInputJSONDelta 编码 JSON 输入增量
|
||||
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil {
|
||||
// 首次 delta,含 id 和 name
|
||||
start := e.bufferedStart.ContentBlock
|
||||
tcIdx := 0
|
||||
if start.ID != "" {
|
||||
tcIdx = e.toolCallIndexMap[start.ID]
|
||||
}
|
||||
delta := map[string]any{
|
||||
"tool_calls": []map[string]any{{
|
||||
"index": tcIdx,
|
||||
"id": start.ID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": start.Name,
|
||||
"arguments": event.Delta.PartialJSON,
|
||||
},
|
||||
}},
|
||||
}
|
||||
e.bufferedStart = nil
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// 后续 delta,仅含 arguments
|
||||
// 通过 index 查找 tool call
|
||||
tcIdx := 0
|
||||
if event.Index != nil {
|
||||
for id, idx := range e.toolCallIndexMap {
|
||||
if idx == tcIdx {
|
||||
_ = id
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
delta := map[string]any{
|
||||
"tool_calls": []map[string]any{{
|
||||
"index": tcIdx,
|
||||
"function": map[string]any{
|
||||
"arguments": event.Delta.PartialJSON,
|
||||
},
|
||||
}},
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeThinkingDelta 编码思考增量
|
||||
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"reasoning_content": event.Delta.Thinking,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
var chunks [][]byte
|
||||
|
||||
if event.StopReason != nil {
|
||||
fr := mapCanonicalToFinishReason(*event.StopReason)
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": fr,
|
||||
}},
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
if event.Usage != nil {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{},
|
||||
"usage": encodeUsage(*event.Usage),
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// encodeDelta 编码 delta 到 SSE chunk
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// marshalChunk 序列化 chunk 为 SSE data
|
||||
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))}
|
||||
}
|
||||
172
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
172
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEvent("chatcmpl-1", "gpt-4")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "data: "))
|
||||
assert.Contains(t, s, "chatcmpl-1")
|
||||
assert.Contains(t, s, "chat.completion.chunk")
|
||||
|
||||
var payload map[string]any
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_TextDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "你好")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStopEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "data: [DONE]\n\n", string(chunks[0]))
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Buffering(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
// ContentBlockStart 应被缓冲,不输出
|
||||
startEvent := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})
|
||||
chunks := e.EncodeEvent(startEvent)
|
||||
assert.Nil(t, chunks)
|
||||
assert.NotNil(t, e.bufferedStart)
|
||||
|
||||
// 第一个 delta 触发输出(清空缓冲)
|
||||
deltaEvent := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "hello"})
|
||||
chunks = e.EncodeEvent(deltaEvent)
|
||||
require.NotEmpty(t, chunks)
|
||||
assert.Nil(t, e.bufferedStart)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStop_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
idx := 0
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Ping_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewPingEvent()
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Error_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewErrorEvent("test_error", "测试错误")
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
chunks := e.Flush()
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ThinkingDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeThinking),
|
||||
Thinking: "思考内容",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "reasoning_content")
|
||||
assert.Contains(t, s, "思考内容")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_InputJSONDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
}))
|
||||
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: "{\"city\":\"北京\"}",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "tool_calls")
|
||||
assert.Contains(t, s, "北京")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
sr := canonical.StopReasonEndTurn
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "finish_reason")
|
||||
assert.Contains(t, s, "stop")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
usage := canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
Usage: &usage,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "usage")
|
||||
assert.Contains(t, s, "prompt_tokens")
|
||||
}
|
||||
245
backend/internal/conversion/openai/types.go
Normal file
245
backend/internal/conversion/openai/types.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package openai
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completion 请求
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// 已废弃字段
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// Message OpenAI 消息
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
// 已废弃
|
||||
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall OpenAI 工具调用
|
||||
type ToolCall struct {
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Custom *CustomTool `json:"custom,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionCall OpenAI 函数调用
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// CustomTool 自定义工具
|
||||
type CustomTool struct {
|
||||
Name string `json:"name"`
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
// FunctionCallMsg 已废弃的函数调用消息
|
||||
type FunctionCallMsg struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// Tool OpenAI 工具定义
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function *FunctionDef `json:"function,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionDef OpenAI 函数定义
|
||||
type FunctionDef struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseFormat OpenAI 响应格式
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// JSONSchemaDef JSON Schema 定义
|
||||
type JSONSchemaDef struct {
|
||||
Name string `json:"name"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// StreamOptions 流式选项
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionResponse OpenAI Chat Completion 响应
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
// Choice OpenAI 选择项
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Logprobs any `json:"logprobs,omitempty"`
|
||||
}
|
||||
|
||||
// Usage OpenAI 用量
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// PromptTokensDetails 提示 Token 详情
|
||||
type PromptTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// CompletionTokensDetails 完成 Token 详情
|
||||
type CompletionTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// StreamChunk OpenAI 流式 chunk
|
||||
type StreamChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
}
|
||||
|
||||
// ModelsResponse OpenAI 模型列表响应
|
||||
type ModelsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []ModelItem `json:"data"`
|
||||
}
|
||||
|
||||
// ModelItem OpenAI 模型项
|
||||
type ModelItem struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
// ModelInfoResponse OpenAI 模型详情响应
|
||||
type ModelInfoResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
// EmbeddingRequest OpenAI 嵌入请求
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse OpenAI 嵌入响应
|
||||
type EmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage EmbeddingUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// EmbeddingData 嵌入数据项
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"`
|
||||
}
|
||||
|
||||
// EmbeddingUsage 嵌入用量
|
||||
type EmbeddingUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// RerankRequest OpenAI 重排序请求
|
||||
type RerankRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
Documents []string `json:"documents"`
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
}
|
||||
|
||||
// RerankResponse OpenAI 重排序响应
|
||||
type RerankResponse struct {
|
||||
Results []RerankResult `json:"results"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// RerankResult 重排序结果项
|
||||
type RerankResult struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
Document *string `json:"document,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse OpenAI 错误响应
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param any `json:"param"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
19
backend/internal/conversion/provider.go
Normal file
19
backend/internal/conversion/provider.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package conversion
|
||||
|
||||
// TargetProvider describes the upstream provider a converted request is
// sent to.
type TargetProvider struct {
	BaseURL   string `json:"base_url"`
	APIKey    string `json:"api_key"`
	ModelName string `json:"model_name"`
	// AdapterConfig carries adapter-specific options; it is always non-nil
	// when built via NewTargetProvider.
	AdapterConfig map[string]any `json:"adapter_config,omitempty"`
}

// NewTargetProvider builds a TargetProvider with an empty, non-nil
// adapter configuration.
func NewTargetProvider(baseURL, apiKey, modelName string) *TargetProvider {
	tp := &TargetProvider{
		BaseURL:       baseURL,
		APIKey:        apiKey,
		ModelName:     modelName,
		AdapterConfig: map[string]any{},
	}
	return tp
}
|
||||
107
backend/internal/conversion/stream.go
Normal file
107
backend/internal/conversion/stream.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package conversion
|
||||
|
||||
import "nex/backend/internal/conversion/canonical"
|
||||
|
||||
// StreamDecoder 流式解码器接口
|
||||
type StreamDecoder interface {
|
||||
ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent
|
||||
Flush() []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
// StreamEncoder 流式编码器接口
|
||||
type StreamEncoder interface {
|
||||
EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte
|
||||
Flush() [][]byte
|
||||
}
|
||||
|
||||
// StreamConverter 流式转换器接口
|
||||
type StreamConverter interface {
|
||||
ProcessChunk(rawChunk []byte) [][]byte
|
||||
Flush() [][]byte
|
||||
}
|
||||
|
||||
// PassthroughStreamConverter forwards SSE bytes unchanged, for
// same-protocol proxying with zero re-serialization overhead.
type PassthroughStreamConverter struct{}

// NewPassthroughStreamConverter returns a pass-through stream converter.
func NewPassthroughStreamConverter() *PassthroughStreamConverter {
	return &PassthroughStreamConverter{}
}

// ProcessChunk returns the raw chunk untouched.
func (c *PassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
	return [][]byte{rawChunk}
}

// Flush reports no buffered data.
func (c *PassthroughStreamConverter) Flush() [][]byte {
	return nil
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
type CanonicalStreamConverter struct {
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||
func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
providerProtocol: providerProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 解码 → 中间件 → 编码管道
|
||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
events := c.decoder.ProcessChunk(rawChunk)
|
||||
var result [][]byte
|
||||
for i := range events {
|
||||
if c.chain != nil {
|
||||
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush 刷新解码器和编码器缓冲区
|
||||
func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
events := c.decoder.Flush()
|
||||
var result [][]byte
|
||||
for i := range events {
|
||||
if c.chain != nil {
|
||||
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
encoderChunks := c.encoder.Flush()
|
||||
result = append(result, encoderChunks...)
|
||||
return result
|
||||
}
|
||||
130
backend/internal/conversion/stream_test.go
Normal file
130
backend/internal/conversion/stream_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPassthroughStreamConverter_ProcessChunk(t *testing.T) {
|
||||
converter := NewPassthroughStreamConverter()
|
||||
data := []byte("hello world")
|
||||
result := converter.ProcessChunk(data)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, data, result[0])
|
||||
}
|
||||
|
||||
func TestPassthroughStreamConverter_Flush(t *testing.T) {
|
||||
converter := NewPassthroughStreamConverter()
|
||||
result := converter.Flush()
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
// mockStreamDecoder 模拟流式解码器
|
||||
type mockStreamDecoder struct {
|
||||
chunks [][]canonical.CanonicalStreamEvent
|
||||
flush []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
// ProcessChunk 弹出下一个分片的事件
|
||||
func (d *mockStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
if len(d.chunks) == 0 {
|
||||
return nil
|
||||
}
|
||||
events := d.chunks[0]
|
||||
d.chunks = d.chunks[1:]
|
||||
return events
|
||||
}
|
||||
|
||||
// Flush 返回刷新事件
|
||||
func (d *mockStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return d.flush
|
||||
}
|
||||
|
||||
// mockStreamEncoder 模拟流式编码器
|
||||
type mockStreamEncoder struct {
|
||||
events [][]byte
|
||||
flush [][]byte
|
||||
}
|
||||
|
||||
// EncodeEvent 返回编码后的事件
|
||||
func (e *mockStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if len(e.events) == 0 {
|
||||
return nil
|
||||
}
|
||||
return e.events
|
||||
}
|
||||
|
||||
// Flush 返回编码器刷新数据
|
||||
func (e *mockStreamEncoder) Flush() [][]byte {
|
||||
return e.flush
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_ProcessChunk(t *testing.T) {
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
chunks: [][]canonical.CanonicalStreamEvent{{event}},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: test\n\n")},
|
||||
}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, []byte("data: test\n\n"), result[0])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
||||
var records []string
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
chunks: [][]canonical.CanonicalStreamEvent{{event}},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: ok\n\n")},
|
||||
}
|
||||
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, []string{"stream:mw1"}, records)
|
||||
assert.Equal(t, []byte("data: ok\n\n"), result[0])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_Flush(t *testing.T) {
|
||||
decoder := &mockStreamDecoder{
|
||||
flush: []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStopEvent(),
|
||||
},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: stop\n\n")},
|
||||
flush: [][]byte{[]byte("data: flush\n\n")},
|
||||
}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, []byte("data: stop\n\n"), result[0])
|
||||
assert.Equal(t, []byte("data: flush\n\n"), result[1])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_EmptyDecoder(t *testing.T) {
|
||||
decoder := &mockStreamDecoder{}
|
||||
encoder := &mockStreamEncoder{}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ type Provider struct {
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Protocol string `json:"protocol"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
|
||||
errMsg := formatValidationErrors(validationErrors)
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: errMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "请求转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, openaiReq, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "响应转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "image" {
|
||||
return fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: appErr.Message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
@@ -34,8 +35,8 @@ func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
@@ -84,61 +85,14 @@ func (m *mockModelService) Update(id string, updates map[string]interface{}) err
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
resp *openai.ChatCompletionResponse
|
||||
eventChan chan provider.StreamEvent
|
||||
err error
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
return m.resp, m.err
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
|
||||
return m.eventChan, m.err
|
||||
}
|
||||
|
||||
// ============ OpenAI Handler 测试 ============
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
|
||||
// 缺少 model 字段
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
|
||||
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
|
||||
h := NewOpenAIHandler(nil, routingSvc, nil)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
@@ -283,8 +237,16 @@ func TestFormatValidationErrors(t *testing.T) {
|
||||
"model": "模型名称不能为空",
|
||||
"messages": "消息列表不能为空",
|
||||
}
|
||||
result := formatValidationErrors(errs)
|
||||
result := formatMapErrors(errs)
|
||||
require.Contains(t, result, "请求验证失败")
|
||||
require.Contains(t, result, "model")
|
||||
require.Contains(t, result, "messages")
|
||||
}
|
||||
|
||||
func formatMapErrors(errs map[string]string) string {
|
||||
parts := make([]string, 0, len(errs))
|
||||
for field, msg := range errs {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, &req, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors 将验证错误 map 格式化为字符串
|
||||
func formatValidationErrors(errors map[string]string) string {
|
||||
parts := make([]string, 0, len(errors))
|
||||
for field, msg := range errors {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
@@ -26,10 +26,11 @@ func NewProviderHandler(providerService service.ProviderService) *ProviderHandle
|
||||
// CreateProvider 创建供应商
|
||||
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -39,11 +40,17 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
protocol := req.Protocol
|
||||
if protocol == "" {
|
||||
protocol = "openai"
|
||||
}
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Protocol: protocol,
|
||||
}
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
|
||||
371
backend/internal/handler/proxy_handler.go
Normal file
371
backend/internal/handler/proxy_handler.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package handler
|
||||
|
||||
import (
	"bufio"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"nex/backend/internal/conversion"
	"nex/backend/internal/domain"
	"nex/backend/internal/provider"
	"nex/backend/internal/service"
)
|
||||
|
||||
// ProxyHandler is the single entry point for all proxied LLM traffic. It
// routes requests by model name, converts between client and provider
// protocols via the ConversionEngine, forwards through a protocol-agnostic
// HTTP client, and records usage statistics asynchronously.
type ProxyHandler struct {
	engine          *conversion.ConversionEngine // request/response/stream protocol conversion
	client          provider.ProviderClient      // protocol-agnostic upstream HTTP sender
	routingService  service.RoutingService       // model name → provider/model routing
	providerService service.ProviderService      // used to pick a provider for model-less passthrough
	statsService    service.StatsService         // usage recording (called in goroutines)
	logger          *zap.Logger
}

// NewProxyHandler creates the unified proxy handler. It uses the process-wide
// global zap logger (zap.L()), so logging must be configured before requests
// are served.
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
	return &ProxyHandler{
		engine:          engine,
		client:          client,
		routingService:  routingService,
		providerService: providerService,
		statsService:    statsService,
		logger:          zap.L(),
	}
}
|
||||
|
||||
// HandleProxy is the gin entry point for /{protocol}/v1/* traffic. It extracts
// the client protocol from the URL, routes by the model name found in the
// JSON body, and dispatches to the streaming or non-streaming pipeline.
// Requests without a resolvable model (e.g. GET endpoints) fall back to
// passthrough forwarding.
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
	// Client protocol comes from the URL prefix: /{protocol}/v1/...
	clientProtocol := c.Param("protocol")
	if clientProtocol == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
		return
	}

	// Rebuild the provider-native path: /v1/{path}
	path := c.Param("path")
	if strings.HasPrefix(path, "/") {
		path = path[1:]
	}
	nativePath := "/v1/" + path

	// Read the full request body (GET requests typically have none).
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
		return
	}

	// Best-effort extraction of the model name from the JSON body.
	modelName := ""
	if len(body) > 0 {
		modelName = extractModelName(body)
	}

	// Assemble the inbound request spec handed to the conversion engine.
	inSpec := conversion.HTTPRequestSpec{
		URL:     nativePath,
		Method:  c.Request.Method,
		Headers: extractHeaders(c),
		Body:    body,
	}

	// Route by model name.
	routeResult, err := h.routingService.Route(modelName)
	if err != nil {
		// No body or no model name: forward as-is (covers GET endpoints
		// such as model listings that carry no routable model).
		if len(body) == 0 || modelName == "" {
			h.forwardPassthrough(c, inSpec, clientProtocol)
			return
		}
		h.writeError(c, err, clientProtocol)
		return
	}

	// Default the provider protocol to OpenAI when unset (legacy providers
	// created before the protocol field existed).
	providerProtocol := routeResult.Provider.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	// Target upstream resolved by routing.
	targetProvider := conversion.NewTargetProvider(
		routeResult.Provider.BaseURL,
		routeResult.Provider.APIKey,
		routeResult.Model.ModelName,
	)

	// Dispatch streaming vs. non-streaming handling.
	isStream := h.isStreamRequest(body, clientProtocol, nativePath)

	if isStream {
		h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
	} else {
		h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
	}
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应
|
||||
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
}()
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
for i, b := range body {
|
||||
if b == '"' && i+8 <= len(body) {
|
||||
if string(body[i:i+8]) == `"stream"` {
|
||||
for j := i + 8; j < len(body) && j < i+20; j++ {
|
||||
if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
c.Data(statusCode, "application/json", body)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
// writeError writes a routing failure to the client.
//
// NOTE(review): clientProtocol is currently unused and every routing error is
// reported as 404 with a bare JSON body, even when the underlying cause is not
// "not found" — confirm whether this should map application errors to their
// own HTTP status and encode the body in the client's protocol.
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
	c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
}
|
||||
|
||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
||||
registry := h.engine.GetRegistry()
|
||||
adapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil || len(providers) == 0 {
|
||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
||||
return
|
||||
}
|
||||
|
||||
p := providers[0]
|
||||
providerProtocol := p.Protocol
|
||||
if providerProtocol == "" {
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
||||
|
||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||
|
||||
var outSpec *conversion.HTTPRequestSpec
|
||||
if clientProtocol == providerProtocol {
|
||||
upstreamURL := p.BaseURL + inSpec.URL
|
||||
headers := adapter.BuildHeaders(targetProvider)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
outSpec = &conversion.HTTPRequestSpec{
|
||||
URL: upstreamURL,
|
||||
Method: inSpec.Method,
|
||||
Headers: headers,
|
||||
Body: inSpec.Body,
|
||||
}
|
||||
} else {
|
||||
outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
}
|
||||
|
||||
// extractModelName pulls the top-level "model" string value out of a raw
// JSON body without a full unmarshal.
//
// It is a single-pass scanner that tracks nesting depth and whether the
// most recently closed string was in key position. This fixes two defects
// in the previous version, which treated every string as a candidate key:
//   - "model" appearing as a *value* (e.g. {"a":"model","b":"x"}) caused
//     the following string ("x") to be returned as the model;
//   - a "model" key inside a nested object was matched even when the
//     top-level request had its own model field.
//
// The value is returned as raw bytes: escape sequences (\" \\ \uXXXX) are
// NOT decoded — model identifiers normally contain none. Returns "" when no
// top-level string-valued "model" key exists.
func extractModelName(body []byte) string {
	var (
		depth    int    // current object/array nesting depth
		inQuote  bool   // inside a JSON string literal
		escaped  bool   // previous byte inside the string was a backslash
		strStart int    // index just past the opening quote
		wantVal  bool   // the next top-level string is the "model" value
		lastStr  string // most recently closed string (candidate key)
	)
	for i := 0; i < len(body); i++ {
		b := body[i]
		if inQuote {
			switch {
			case escaped:
				escaped = false
			case b == '\\':
				escaped = true
			case b == '"':
				inQuote = false
				s := string(body[strStart:i])
				if wantVal {
					return s
				}
				lastStr = s
			}
			continue
		}
		switch b {
		case '"':
			inQuote = true
			strStart = i + 1
		case '{', '[':
			depth++
			lastStr, wantVal = "", false
		case '}', ']':
			depth--
			lastStr, wantVal = "", false
		case ':':
			// Only a colon directly following a top-level "model" key arms
			// value capture; any other colon disarms it.
			wantVal = depth == 1 && lastStr == "model"
			lastStr = ""
		case ',':
			lastStr, wantVal = "", false
		}
	}
	return ""
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
func extractHeaders(c *gin.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for k, vs := range c.Request.Header {
|
||||
if len(vs) > 0 {
|
||||
headers[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
@@ -1,234 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// ConvertRequest converts an Anthropic Messages request into an OpenAI
// Chat Completions request.
//
// Mapping notes:
//   - max_tokens is required by Anthropic; values <= 0 fall back to 4096;
//   - the system prompt becomes a leading role="system" message;
//   - each Anthropic message may expand into several OpenAI messages
//     (see convertMessage);
//   - TopK and Metadata have no OpenAI equivalent and are dropped.
func ConvertRequest(anthropicReq *MessagesRequest) (*openai.ChatCompletionRequest, error) {
	openaiReq := &openai.ChatCompletionRequest{
		Model:       anthropicReq.Model,
		Temperature: anthropicReq.Temperature,
		TopP:        anthropicReq.TopP,
		Stream:      anthropicReq.Stream,
	}

	// max_tokens (Anthropic requires it; default to 4096 when unset).
	if anthropicReq.MaxTokens > 0 {
		openaiReq.MaxTokens = &anthropicReq.MaxTokens
	} else {
		defaultMax := 4096
		openaiReq.MaxTokens = &defaultMax
	}

	// stop_sequences -> stop.
	if len(anthropicReq.StopSequences) > 0 {
		openaiReq.Stop = anthropicReq.StopSequences
	}

	// System prompt becomes a leading system message.
	messages := make([]openai.Message, 0)
	if anthropicReq.System != "" {
		messages = append(messages, openai.Message{
			Role:    "system",
			Content: anthropicReq.System,
		})
	}

	// Convert conversation messages; one Anthropic message may yield
	// several OpenAI messages.
	for _, msg := range anthropicReq.Messages {
		openaiMsg, err := convertMessage(msg)
		if err != nil {
			return nil, err
		}
		messages = append(messages, openaiMsg...)
	}
	openaiReq.Messages = messages

	// Tools: Anthropic's input_schema maps directly to OpenAI function
	// parameters.
	if len(anthropicReq.Tools) > 0 {
		openaiReq.Tools = make([]openai.Tool, len(anthropicReq.Tools))
		for i, tool := range anthropicReq.Tools {
			openaiReq.Tools[i] = openai.Tool{
				Type: "function",
				Function: openai.FunctionDefinition{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  tool.InputSchema,
				},
			}
		}
	}

	// tool_choice (string or object form; see convertToolChoice).
	if anthropicReq.ToolChoice != nil {
		toolChoice, err := convertToolChoice(anthropicReq.ToolChoice)
		if err != nil {
			return nil, err
		}
		openaiReq.ToolChoice = toolChoice
	}

	return openaiReq, nil
}
|
||||
|
||||
// ConvertResponse 将 OpenAI 响应转换为 Anthropic 响应
|
||||
func ConvertResponse(openaiResp *openai.ChatCompletionResponse) (*MessagesResponse, error) {
|
||||
anthropicResp := &MessagesResponse{
|
||||
ID: openaiResp.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openaiResp.Model,
|
||||
Usage: Usage{
|
||||
InputTokens: openaiResp.Usage.PromptTokens,
|
||||
OutputTokens: openaiResp.Usage.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
// 转换 content
|
||||
if len(openaiResp.Choices) > 0 {
|
||||
choice := openaiResp.Choices[0]
|
||||
content := make([]ContentBlock, 0)
|
||||
|
||||
if choice.Message != nil {
|
||||
// 文本内容
|
||||
if choice.Message.Content != "" {
|
||||
if str, ok := choice.Message.Content.(string); ok && str != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: str,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
// 解析 arguments JSON
|
||||
var input interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil {
|
||||
return nil, fmt.Errorf("解析 tool_call arguments 失败: %w", err)
|
||||
}
|
||||
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anthropicResp.Content = content
|
||||
|
||||
// 转换 finish_reason
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
anthropicResp.StopReason = "end_turn"
|
||||
case "tool_calls":
|
||||
anthropicResp.StopReason = "tool_use"
|
||||
case "length":
|
||||
anthropicResp.StopReason = "max_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
return anthropicResp, nil
|
||||
}
|
||||
|
||||
// convertMessage converts one Anthropic message into one or more OpenAI
// messages. Each content block becomes its own OpenAI message:
//   - "text"        -> a message with the original role
//   - "tool_result" -> a role="tool" message carrying ToolCallID
//   - "image"       -> rejected (multimodal is out of MVP scope)
//
// NOTE(review): "tool_use" blocks (assistant tool invocations replayed in
// history) fall into the default branch and return an error — confirm
// callers never send assistant tool_use turns through this path.
func convertMessage(msg AnthropicMessage) ([]openai.Message, error) {
	var messages []openai.Message

	// Expand each content block into its own message.
	for _, block := range msg.Content {
		switch block.Type {
		case "text":
			// Plain text keeps the original role.
			messages = append(messages, openai.Message{
				Role:    msg.Role,
				Content: block.Text,
			})

		case "tool_result":
			// Tool results may be a plain string or structured data;
			// non-strings are serialized to JSON.
			content := ""
			if str, ok := block.Content.(string); ok {
				content = str
			} else {
				bytes, err := json.Marshal(block.Content)
				if err != nil {
					return nil, fmt.Errorf("序列化 tool_result 内容失败: %w", err)
				}
				content = string(bytes)
			}

			messages = append(messages, openai.Message{
				Role:       "tool",
				Content:    content,
				ToolCallID: block.ToolUseID,
			})

		case "image":
			// Multimodal content is not supported in the MVP.
			return nil, fmt.Errorf("MVP 不支持多模态内容(图片)")

		default:
			return nil, fmt.Errorf("未知的内容块类型: %s", block.Type)
		}
	}

	// Guard against an empty content array (should not happen): emit an
	// empty message so the role is still represented.
	if len(messages) == 0 {
		messages = append(messages, openai.Message{
			Role:    msg.Role,
			Content: "",
		})
	}

	return messages, nil
}
|
||||
|
||||
// convertToolChoice maps an Anthropic tool_choice (string "auto"/"any", or
// an object {"type": ..., "name": ...}) onto the OpenAI representation.
// Both "auto" and "any" collapse to OpenAI's "auto"; type="tool" becomes a
// forced function choice. Any other shape is rejected.
func convertToolChoice(choice interface{}) (interface{}, error) {
	switch v := choice.(type) {
	case string:
		if v == "auto" || v == "any" {
			return "auto", nil
		}
		return nil, fmt.Errorf("无效的 tool_choice 字符串: %s", v)

	case map[string]interface{}:
		typ, ok := v["type"].(string)
		if !ok {
			return nil, fmt.Errorf("tool_choice 对象缺少 type 字段")
		}
		switch typ {
		case "auto", "any":
			return "auto", nil
		case "tool":
			name, ok := v["name"].(string)
			if !ok {
				return nil, fmt.Errorf("tool_choice type=tool 缺少 name 字段")
			}
			return map[string]interface{}{
				"type": "function",
				"function": map[string]string{
					"name": name,
				},
			}, nil
		default:
			return nil, fmt.Errorf("无效的 tool_choice type: %s", typ)
		}

	default:
		return nil, fmt.Errorf("tool_choice 格式无效")
	}
}
|
||||
@@ -1,270 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// TestConvertRequest_Basic verifies the minimal field mapping: model,
// max_tokens, temperature pointer, and a single text message.
func TestConvertRequest_Basic(t *testing.T) {
	temp := 0.7
	req := &MessagesRequest{
		Model:       "claude-3-opus",
		MaxTokens:   1024,
		Temperature: &temp,
		Messages: []AnthropicMessage{
			{
				Role: "user",
				Content: []ContentBlock{
					{Type: "text", Text: "Hello"},
				},
			},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	assert.Equal(t, "claude-3-opus", result.Model)
	assert.Equal(t, 1024, *result.MaxTokens)
	assert.Equal(t, &temp, result.Temperature)
	require.Len(t, result.Messages, 1)
	assert.Equal(t, "user", result.Messages[0].Role)
	assert.Equal(t, "Hello", result.Messages[0].Content)
}
|
||||
|
||||
// TestConvertRequest_WithSystem verifies the Anthropic system prompt is
// prepended as an OpenAI role="system" message.
func TestConvertRequest_WithSystem(t *testing.T) {
	req := &MessagesRequest{
		Model:     "claude-3-opus",
		MaxTokens: 100,
		System:    "You are a helpful assistant.",
		Messages: []AnthropicMessage{
			{
				Role:    "user",
				Content: []ContentBlock{{Type: "text", Text: "Hi"}},
			},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	require.Len(t, result.Messages, 2)
	assert.Equal(t, "system", result.Messages[0].Role)
	assert.Equal(t, "You are a helpful assistant.", result.Messages[0].Content)
	assert.Equal(t, "user", result.Messages[1].Role)
}
|
||||
|
||||
// TestConvertRequest_DefaultMaxTokens verifies the 4096 fallback when
// max_tokens is unset (zero).
func TestConvertRequest_DefaultMaxTokens(t *testing.T) {
	req := &MessagesRequest{
		Model:     "claude-3-opus",
		MaxTokens: 0, // unset
		Messages: []AnthropicMessage{
			{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	assert.Equal(t, 4096, *result.MaxTokens)
}
|
||||
|
||||
// TestConvertRequest_WithTools verifies Anthropic tools map to OpenAI
// type="function" tools with name/description/parameters preserved.
func TestConvertRequest_WithTools(t *testing.T) {
	req := &MessagesRequest{
		Model:     "claude-3-opus",
		MaxTokens: 100,
		Messages: []AnthropicMessage{
			{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
		},
		Tools: []AnthropicTool{
			{
				Name:        "get_weather",
				Description: "Get weather info",
				InputSchema: map[string]interface{}{"type": "object"},
			},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	require.Len(t, result.Tools, 1)
	assert.Equal(t, "function", result.Tools[0].Type)
	assert.Equal(t, "get_weather", result.Tools[0].Function.Name)
}
|
||||
|
||||
// TestConvertRequest_WithStopSequences verifies stop_sequences pass through
// to the OpenAI stop field unchanged.
func TestConvertRequest_WithStopSequences(t *testing.T) {
	req := &MessagesRequest{
		Model:         "claude-3-opus",
		MaxTokens:     100,
		StopSequences: []string{"STOP", "END"},
		Messages: []AnthropicMessage{
			{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	assert.Equal(t, []string{"STOP", "END"}, result.Stop)
}
|
||||
|
||||
// TestConvertRequest_ToolResult verifies a tool_result block becomes a
// role="tool" OpenAI message carrying the tool_use_id.
func TestConvertRequest_ToolResult(t *testing.T) {
	req := &MessagesRequest{
		Model:     "claude-3-opus",
		MaxTokens: 100,
		Messages: []AnthropicMessage{
			{
				Role: "user",
				Content: []ContentBlock{
					{
						Type:      "tool_result",
						ToolUseID: "tool_123",
						Content:   "result data",
					},
				},
			},
		},
	}

	result, err := ConvertRequest(req)
	require.NoError(t, err)
	require.Len(t, result.Messages, 1)
	assert.Equal(t, "tool", result.Messages[0].Role)
	assert.Equal(t, "tool_123", result.Messages[0].ToolCallID)
	assert.Equal(t, "result data", result.Messages[0].Content)
}
|
||||
|
||||
// TestConvertResponse verifies the text-response path: id/role/type, the
// end_turn stop-reason mapping, the text content block, and usage counts.
func TestConvertResponse(t *testing.T) {
	resp := &openai.ChatCompletionResponse{
		ID:    "chatcmpl-123",
		Model: "gpt-4",
		Choices: []openai.Choice{
			{
				Index:        0,
				Message:      &openai.Message{Role: "assistant", Content: "Hello!"},
				FinishReason: "stop",
			},
		},
		Usage: openai.Usage{PromptTokens: 10, CompletionTokens: 5},
	}

	result, err := ConvertResponse(resp)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", result.ID)
	assert.Equal(t, "message", result.Type)
	assert.Equal(t, "assistant", result.Role)
	assert.Equal(t, "end_turn", result.StopReason)
	require.Len(t, result.Content, 1)
	assert.Equal(t, "text", result.Content[0].Type)
	assert.Equal(t, "Hello!", result.Content[0].Text)
	assert.Equal(t, 10, result.Usage.InputTokens)
	assert.Equal(t, 5, result.Usage.OutputTokens)
}
|
||||
|
||||
// TestConvertResponse_ToolCalls verifies tool calls become tool_use content
// blocks and finish_reason "tool_calls" maps to stop reason "tool_use".
func TestConvertResponse_ToolCalls(t *testing.T) {
	args, _ := json.Marshal(map[string]interface{}{"city": "Beijing"})
	resp := &openai.ChatCompletionResponse{
		ID:    "chatcmpl-456",
		Model: "gpt-4",
		Choices: []openai.Choice{
			{
				Index: 0,
				Message: &openai.Message{
					Role: "assistant",
					ToolCalls: []openai.ToolCall{
						{
							ID:   "call_123",
							Type: "function",
							Function: openai.FunctionCall{
								Name:      "get_weather",
								Arguments: string(args),
							},
						},
					},
				},
				FinishReason: "tool_calls",
			},
		},
		Usage: openai.Usage{},
	}

	result, err := ConvertResponse(resp)
	require.NoError(t, err)
	assert.Equal(t, "tool_use", result.StopReason)
	require.Len(t, result.Content, 1)
	assert.Equal(t, "tool_use", result.Content[0].Type)
	assert.Equal(t, "call_123", result.Content[0].ID)
	assert.Equal(t, "get_weather", result.Content[0].Name)
}
|
||||
|
||||
// TestConvertToolChoice_String is a table test covering every tool_choice
// shape: the "auto"/"any" strings, an invalid string, the tool object form
// (with and without its required fields), and a non-string/non-map value.
func TestConvertToolChoice_String(t *testing.T) {
	tests := []struct {
		name    string
		input   interface{}
		wantErr bool
		check   func(interface{})
	}{
		{"auto字符串", "auto", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
		{"any字符串", "any", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
		{"无效字符串", "invalid", true, nil},
		{"tool对象", map[string]interface{}{"type": "tool", "name": "my_func"}, false,
			func(r interface{}) {
				m := r.(map[string]interface{})
				assert.Equal(t, "function", m["type"])
			}},
		{"缺少name的tool对象", map[string]interface{}{"type": "tool"}, true, nil},
		{"缺少type的对象", map[string]interface{}{"name": "func"}, true, nil},
		{"无效类型", 42, true, nil},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := convertToolChoice(tt.input)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				require.NoError(t, err)
				tt.check(result)
			}
		})
	}
}
|
||||
|
||||
// TestValidateRequest covers request validation: a valid request yields no
// errors; a missing model and a zero max_tokens are both rejected.
func TestValidateRequest(t *testing.T) {
	t.Run("有效请求", func(t *testing.T) {
		req := &MessagesRequest{
			Model:     "claude-3-opus",
			MaxTokens: 100,
			Messages: []AnthropicMessage{
				{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
			},
		}
		errs := ValidateRequest(req)
		assert.Nil(t, errs)
	})

	t.Run("缺少模型", func(t *testing.T) {
		req := &MessagesRequest{
			MaxTokens: 100,
			Messages: []AnthropicMessage{
				{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
			},
		}
		errs := ValidateRequest(req)
		assert.NotNil(t, errs)
		assert.Contains(t, errs["model"], "不能为空")
	})

	t.Run("MaxTokens为0", func(t *testing.T) {
		req := &MessagesRequest{
			Model:     "claude-3-opus",
			MaxTokens: 0,
			Messages: []AnthropicMessage{
				{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
			},
		}
		errs := ValidateRequest(req)
		assert.NotNil(t, errs)
	})
}
|
||||
@@ -1,164 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// StreamConverter incrementally rewrites an OpenAI SSE chat stream into
// Anthropic Messages stream events. It is stateful: one converter handles
// exactly one logical message, tracking which framing events (message_start,
// per-block content_block_start) have already been emitted.
type StreamConverter struct {
	messageID      string         // Anthropic message id emitted in message_start
	model          string         // model name emitted in message_start
	index          int            // current text content-block index (never advanced in this implementation)
	toolCallArgs   map[int]string // accumulated argument JSON per tool_use block index
	sentStart      bool           // message_start already emitted
	sentBlockStart map[int]bool   // content_block_start already emitted, per block index
}
|
||||
|
||||
// NewStreamConverter 创建流式转换器
|
||||
func NewStreamConverter(messageID, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
messageID: messageID,
|
||||
model: model,
|
||||
index: 0,
|
||||
toolCallArgs: make(map[int]string),
|
||||
sentStart: false,
|
||||
sentBlockStart: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertChunk converts one OpenAI stream chunk into zero or more Anthropic
// stream events, emitting the Anthropic framing (message_start, per-block
// content_block_start) lazily the first time it is needed.
//
// NOTE(review): toolIndex is recomputed as index+len(toolCallArgs) on every
// delta, so a tool call whose arguments arrive across multiple chunks gets a
// NEW block (with empty ID/Name) for each continuation — confirm upstream
// always delivers a tool call's arguments in a single chunk.
// NOTE(review): content_block_stop events iterate the sentBlockStart map, so
// their order is nondeterministic when several blocks are open.
func (c *StreamConverter) ConvertChunk(chunk *openai.StreamChunk) ([]StreamEvent, error) {
	var events []StreamEvent

	// Emit message_start exactly once, with zeroed usage (final counts are
	// not known at stream start).
	if !c.sentStart {
		events = append(events, StreamEvent{
			Type: "message_start",
			Message: &MessagesResponse{
				ID:      c.messageID,
				Type:    "message",
				Role:    "assistant",
				Model:   c.model,
				Content: []ContentBlock{},
				Usage: Usage{
					InputTokens:  0,
					OutputTokens: 0,
				},
			},
		})
		c.sentStart = true
	}

	// Process every choice in the chunk.
	for _, choice := range chunk.Choices {
		// Text delta: open the text block on first use, then stream the
		// fragment as a text_delta.
		if choice.Delta.Content != "" {
			if !c.sentBlockStart[c.index] {
				events = append(events, StreamEvent{
					Type:  "content_block_start",
					Index: c.index,
					ContentBlock: &ContentBlock{
						Type: "text",
					},
				})
				c.sentBlockStart[c.index] = true
			}

			events = append(events, StreamEvent{
				Type:  "content_block_delta",
				Index: c.index,
				Delta: &Delta{
					Type: "text_delta",
					Text: choice.Delta.Content,
				},
			})
		}

		// Tool-call deltas: each becomes a tool_use block plus an
		// input_json_delta carrying the (partial) argument JSON.
		if len(choice.Delta.ToolCalls) > 0 {
			for _, tc := range choice.Delta.ToolCalls {
				// Derive the block index for this tool call (see NOTE above).
				toolIndex := c.index + len(c.toolCallArgs)

				if !c.sentBlockStart[toolIndex] {
					events = append(events, StreamEvent{
						Type:  "content_block_start",
						Index: toolIndex,
						ContentBlock: &ContentBlock{
							Type: "tool_use",
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
					})
					c.sentBlockStart[toolIndex] = true
					c.toolCallArgs[toolIndex] = ""
				}

				// Accumulate the raw argument JSON for this block.
				c.toolCallArgs[toolIndex] += tc.Function.Arguments

				events = append(events, StreamEvent{
					Type:  "content_block_delta",
					Index: toolIndex,
					Delta: &Delta{
						Type:  "input_json_delta",
						Input: tc.Function.Arguments,
					},
				})
			}
		}

		// Finish: close every open block, then emit message_delta with the
		// mapped stop reason and a final message_stop.
		if choice.FinishReason != "" {
			for idx := range c.sentBlockStart {
				events = append(events, StreamEvent{
					Type:  "content_block_stop",
					Index: idx,
				})
			}

			// Map finish_reason; unknown values yield an empty stop_reason.
			stopReason := ""
			switch choice.FinishReason {
			case "stop":
				stopReason = "end_turn"
			case "tool_calls":
				stopReason = "tool_use"
			case "length":
				stopReason = "max_tokens"
			}

			events = append(events, StreamEvent{
				Type: "message_delta",
				Delta: &Delta{
					StopReason: stopReason,
				},
			})

			events = append(events, StreamEvent{
				Type: "message_stop",
			})
		}
	}

	return events, nil
}
|
||||
|
||||
// SerializeEvent 序列化事件为 SSE 格式
|
||||
func SerializeEvent(event StreamEvent) (string, error) {
|
||||
bytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", event.Type, string(bytes)), nil
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// TestStreamConverter_MessageStart verifies the first chunk produces a
// message_start event with the converter's id, role, and model.
func TestStreamConverter_MessageStart(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	chunk := &openai.StreamChunk{
		ID:      "chatcmpl-123",
		Choices: []openai.StreamChoice{{Index: 0, Delta: openai.Delta{}}},
	}

	events, err := converter.ConvertChunk(chunk)
	require.NoError(t, err)
	require.NotEmpty(t, events)

	// The first event must be message_start.
	assert.Equal(t, "message_start", events[0].Type)
	require.NotNil(t, events[0].Message)
	assert.Equal(t, "msg_123", events[0].Message.ID)
	assert.Equal(t, "message", events[0].Message.Type)
	assert.Equal(t, "assistant", events[0].Message.Role)
	assert.Equal(t, "claude-3-opus", events[0].Message.Model)
}
|
||||
|
||||
// TestStreamConverter_TextDelta verifies framing events are emitted only
// once: the first text chunk yields message_start + content_block_start +
// the delta, while subsequent chunks yield only the delta.
func TestStreamConverter_TextDelta(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	// First chunk triggers message_start and content_block_start.
	chunk1 := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{Delta: openai.Delta{Content: "Hello"}},
		},
	}
	events1, err := converter.ConvertChunk(chunk1)
	require.NoError(t, err)
	// Expect message_start + content_block_start + text delta.
	assert.GreaterOrEqual(t, len(events1), 3)

	// A second text chunk must not repeat message_start / content_block_start.
	chunk2 := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{Delta: openai.Delta{Content: " world"}},
		},
	}
	events2, err := converter.ConvertChunk(chunk2)
	require.NoError(t, err)
	// Only the text delta remains.
	assert.Len(t, events2, 1)
	assert.Equal(t, "content_block_delta", events2[0].Type)
	assert.Equal(t, "text_delta", events2[0].Delta.Type)
	assert.Equal(t, " world", events2[0].Delta.Text)
}
|
||||
|
||||
// TestStreamConverter_FinishReason verifies finish_reason "stop" produces a
// message_delta with stop_reason "end_turn" followed by a message_stop.
func TestStreamConverter_FinishReason(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	chunk := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{Delta: openai.Delta{Content: "Hello"}, FinishReason: "stop"},
		},
	}
	events, err := converter.ConvertChunk(chunk)
	require.NoError(t, err)

	// Locate the message_delta event.
	var messageDelta *StreamEvent
	for _, e := range events {
		if e.Type == "message_delta" {
			messageDelta = &e
			break
		}
	}
	require.NotNil(t, messageDelta)
	assert.Equal(t, "end_turn", messageDelta.Delta.StopReason)

	// Locate the message_stop event.
	var messageStop *StreamEvent
	for _, e := range events {
		if e.Type == "message_stop" {
			messageStop = &e
			break
		}
	}
	assert.NotNil(t, messageStop)
}
|
||||
|
||||
// TestStreamConverter_FinishReasonToolCalls verifies finish_reason
// "tool_calls" maps to stop_reason "tool_use".
func TestStreamConverter_FinishReasonToolCalls(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	chunk := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{Delta: openai.Delta{}, FinishReason: "tool_calls"},
		},
	}
	events, err := converter.ConvertChunk(chunk)
	require.NoError(t, err)

	var messageDelta *StreamEvent
	for _, e := range events {
		if e.Type == "message_delta" {
			messageDelta = &e
			break
		}
	}
	require.NotNil(t, messageDelta)
	assert.Equal(t, "tool_use", messageDelta.Delta.StopReason)
}
|
||||
|
||||
// TestStreamConverter_FinishReasonLength verifies finish_reason "length"
// maps to stop_reason "max_tokens".
func TestStreamConverter_FinishReasonLength(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	chunk := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{Delta: openai.Delta{}, FinishReason: "length"},
		},
	}
	events, err := converter.ConvertChunk(chunk)
	require.NoError(t, err)

	var messageDelta *StreamEvent
	for _, e := range events {
		if e.Type == "message_delta" {
			messageDelta = &e
			break
		}
	}
	require.NotNil(t, messageDelta)
	assert.Equal(t, "max_tokens", messageDelta.Delta.StopReason)
}
|
||||
|
||||
// TestStreamConverter_ToolCalls verifies a tool-call delta produces a
// tool_use content_block_start (with id/name) and an input_json_delta
// carrying the raw argument JSON.
func TestStreamConverter_ToolCalls(t *testing.T) {
	converter := NewStreamConverter("msg_123", "claude-3-opus")

	chunk := &openai.StreamChunk{
		Choices: []openai.StreamChoice{
			{
				Delta: openai.Delta{
					ToolCalls: []openai.ToolCall{
						{
							ID:   "call_123",
							Type: "function",
							Function: openai.FunctionCall{
								Name:      "get_weather",
								Arguments: `{"city": "Beijing"}`,
							},
						},
					},
				},
			},
		},
	}

	events, err := converter.ConvertChunk(chunk)
	require.NoError(t, err)

	// Expect content_block_start (tool_use) + content_block_delta (input_json_delta).
	hasBlockStart := false
	hasInputDelta := false
	for _, e := range events {
		if e.Type == "content_block_start" && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
			hasBlockStart = true
			assert.Equal(t, "call_123", e.ContentBlock.ID)
			assert.Equal(t, "get_weather", e.ContentBlock.Name)
		}
		if e.Type == "content_block_delta" && e.Delta != nil && e.Delta.Type == "input_json_delta" {
			hasInputDelta = true
			assert.Equal(t, `{"city": "Beijing"}`, e.Delta.Input)
		}
	}
	assert.True(t, hasBlockStart, "应有 tool_use content_block_start")
	assert.True(t, hasInputDelta, "应有 input_json_delta")
}
|
||||
|
||||
// TestSerializeEvent verifies the SSE frame contains the event name line,
// the data line, and the JSON payload.
func TestSerializeEvent(t *testing.T) {
	event := StreamEvent{
		Type: "message_start",
		Message: &MessagesResponse{
			ID:   "msg_123",
			Type: "message",
			Role: "assistant",
		},
	}

	result, err := SerializeEvent(event)
	require.NoError(t, err)
	assert.Contains(t, result, "event: message_start")
	assert.Contains(t, result, "data: ")
	assert.Contains(t, result, "msg_123")
}
|
||||
|
||||
// TestSerializeEvent_InvalidJSON verifies a minimal event (type only, all
// pointers nil) still serializes without error.
func TestSerializeEvent_InvalidJSON(t *testing.T) {
	event := StreamEvent{
		Type: "test",
	}
	// This should serialize normally.
	result, err := SerializeEvent(event)
	require.NoError(t, err)
	assert.Contains(t, result, "event: test")
}
|
||||
|
||||
// TestContentBlock_ParseInputJSON covers ParseInputJSON's three input
// shapes: a JSON string, an already-decoded map, and an unsupported type.
func TestContentBlock_ParseInputJSON(t *testing.T) {
	t.Run("字符串输入", func(t *testing.T) {
		cb := &ContentBlock{Input: `{"key": "value"}`}
		result, err := cb.ParseInputJSON()
		require.NoError(t, err)
		assert.Equal(t, "value", result["key"])
	})

	t.Run("对象输入", func(t *testing.T) {
		cb := &ContentBlock{Input: map[string]interface{}{"key": "value"}}
		result, err := cb.ParseInputJSON()
		require.NoError(t, err)
		assert.Equal(t, "value", result["key"])
	})

	t.Run("无效类型", func(t *testing.T) {
		cb := &ContentBlock{Input: 42}
		_, err := cb.ParseInputJSON()
		assert.Error(t, err)
	})
}
|
||||
@@ -1,149 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
pkgValidator "nex/backend/pkg/validator"
|
||||
)
|
||||
|
||||
// MessagesRequest is the Anthropic Messages API request payload.
type MessagesRequest struct {
	Model         string                 `json:"model" validate:"required"`
	Messages      []AnthropicMessage     `json:"messages" validate:"required,min=1"`
	System        string                 `json:"system,omitempty"`
	MaxTokens     int                    `json:"max_tokens" validate:"required,min=1"`
	Temperature   *float64               `json:"temperature,omitempty"`
	TopP          *float64               `json:"top_p,omitempty"`
	TopK          *int                   `json:"top_k,omitempty"`
	StopSequences []string               `json:"stop_sequences,omitempty"`
	Stream        bool                   `json:"stream,omitempty"`
	Tools         []AnthropicTool        `json:"tools,omitempty"`
	ToolChoice    interface{}            `json:"tool_choice,omitempty"` // string ("auto"/"any") or object form
	Metadata      map[string]interface{} `json:"metadata,omitempty"`
}
|
||||
|
||||
// AnthropicMessage is a single conversation turn; content is always an
// array of typed blocks.
type AnthropicMessage struct {
	Role    string         `json:"role"`
	Content []ContentBlock `json:"content"`
}
|
||||
|
||||
// ContentBlock is one element of a message's content array. Which fields
// are populated depends on Type.
type ContentBlock struct {
	Type  string      `json:"type"` // "text", "image", "tool_use", "tool_result"
	Text  string      `json:"text,omitempty"`
	Input interface{} `json:"input,omitempty"` // for tool_use

	// tool_use fields
	ID   string `json:"id,omitempty"`
	Name string `json:"name,omitempty"`

	// tool_result fields
	ToolUseID string      `json:"tool_use_id,omitempty"`
	Content   interface{} `json:"content,omitempty"` // string or array

	// multimodal fields (not supported in the MVP)
	Source interface{} `json:"source,omitempty"` // for image
}
|
||||
|
||||
// AnthropicTool Anthropic 工具定义
|
||||
type AnthropicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ToolChoice 工具选择
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool"
|
||||
Name string `json:"name,omitempty"` // 当 type="tool" 时使用
|
||||
}
|
||||
|
||||
// MessagesResponse is the response body of the Anthropic Messages API.
type MessagesResponse struct {
	ID           string         `json:"id"`
	Type         string         `json:"type"` // "message"
	Role         string         `json:"role"` // "assistant"
	Content      []ContentBlock `json:"content"`
	Model        string         `json:"model"`
	StopReason   string         `json:"stop_reason,omitempty"` // "end_turn", "max_tokens", "stop_sequence", "tool_use"
	StopSequence string         `json:"stop_sequence,omitempty"`
	Usage        Usage          `json:"usage"`
}

// Usage reports token consumption for a request/response pair.
type Usage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// StreamEvent is a single event in an Anthropic SSE stream.
type StreamEvent struct {
	Type         string            `json:"type"`
	Message      *MessagesResponse `json:"message,omitempty"`       // for message_start
	Index        int               `json:"index,omitempty"`         // for content_block_* events
	ContentBlock *ContentBlock     `json:"content_block,omitempty"` // for content_block_start
	Delta        *Delta            `json:"delta,omitempty"`         // for content_block_delta
}

// Delta carries incremental streamed content.
type Delta struct {
	Type       string `json:"type,omitempty"` // "text_delta", "input_json_delta"
	Text       string `json:"text,omitempty"`
	Input      string `json:"input,omitempty"`       // partial JSON for tool_use input
	StopReason string `json:"stop_reason,omitempty"` // for message_delta
	Usage      *Usage `json:"usage,omitempty"`       // for message_delta
}

// ErrorResponse is the Anthropic error envelope.
type ErrorResponse struct {
	Type  string      `json:"type"` // "error"
	Error ErrorDetail `json:"error"`
}

// ErrorDetail describes a single Anthropic API error.
type ErrorDetail struct {
	Type    string `json:"type"` // "invalid_request_error", "authentication_error", etc.
	Message string `json:"message"`
}
|
||||
|
||||
// ParseInputJSON 解析 tool_use 的 input(从 JSON 字符串转为 map)
|
||||
func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) {
|
||||
if str, ok := cb.Input.(string); ok {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(str), &result)
|
||||
return result, err
|
||||
}
|
||||
// 如果已经是对象,直接返回
|
||||
if obj, ok := cb.Input.(map[string]interface{}); ok {
|
||||
return obj, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid input type: expected string or map")
|
||||
}
|
||||
|
||||
// ValidateRequest 验证 MessagesRequest
|
||||
func ValidateRequest(req *MessagesRequest) map[string]string {
|
||||
errs := pkgValidator.Validate(req)
|
||||
if errs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
validationErrors := make(map[string]string)
|
||||
for _, err := range errs.(validator.ValidationErrors) {
|
||||
field := err.Field()
|
||||
switch field {
|
||||
case "Model":
|
||||
validationErrors["model"] = "模型名称不能为空"
|
||||
case "Messages":
|
||||
validationErrors["messages"] = "消息列表不能为空"
|
||||
case "MaxTokens":
|
||||
validationErrors["max_tokens"] = "max_tokens 不能为空且必须大于 0"
|
||||
default:
|
||||
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
|
||||
}
|
||||
}
|
||||
return validationErrors
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Adapter is the OpenAI protocol adapter. Requests and responses are
// passed through without any cross-protocol transformation.
type Adapter struct{}

// NewAdapter returns a ready-to-use OpenAI adapter.
func NewAdapter() *Adapter {
	return new(Adapter)
}
|
||||
|
||||
// PrepareRequest 准备发送给供应商的请求(透传)
|
||||
func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL string) (*http.Request, error) {
|
||||
// 序列化请求体
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
// baseURL 已包含版本路径(如 /v1 或 /v4),只需添加端点路径
|
||||
httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
// ParseResponse 解析供应商响应(透传)
|
||||
func (a *Adapter) ParseResponse(resp *http.Response) (*ChatCompletionResponse, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result ChatCompletionResponse
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ParseErrorResponse 解析错误响应
|
||||
func (a *Adapter) ParseErrorResponse(resp *http.Response) (*ErrorResponse, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result ErrorResponse
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ParseStreamChunk 解析流式响应块
|
||||
func (a *Adapter) ParseStreamChunk(data []byte) (*StreamChunk, error) {
|
||||
var chunk StreamChunk
|
||||
err := json.Unmarshal(data, &chunk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &chunk, nil
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAdapter_PrepareRequest verifies that the pass-through adapter builds
// a POST to {baseURL}/chat/completions with JSON content type, bearer auth,
// and the serialized request as the body.
func TestAdapter_PrepareRequest(t *testing.T) {
	adapter := NewAdapter()
	req := &ChatCompletionRequest{
		Model: "gpt-4",
		Messages: []Message{
			{Role: "user", Content: "Hello"},
		},
	}

	httpReq, err := adapter.PrepareRequest(req, "test-api-key", "https://api.openai.com/v1")
	require.NoError(t, err)
	require.NotNil(t, httpReq)

	assert.Equal(t, "POST", httpReq.Method)
	assert.Equal(t, "https://api.openai.com/v1/chat/completions", httpReq.URL.String())
	assert.Equal(t, "application/json", httpReq.Header.Get("Content-Type"))
	assert.Equal(t, "Bearer test-api-key", httpReq.Header.Get("Authorization"))

	// The request body must round-trip back into the original request.
	var body ChatCompletionRequest
	err = json.NewDecoder(httpReq.Body).Decode(&body)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", body.Model)
}

// TestAdapter_ParseResponse verifies that a well-formed HTTP 200 body is
// decoded into a ChatCompletionResponse with all fields intact.
func TestAdapter_ParseResponse(t *testing.T) {
	adapter := NewAdapter()
	resp := &ChatCompletionResponse{
		ID:      "chatcmpl-123",
		Object:  "chat.completion",
		Created: 1234567890,
		Model:   "gpt-4",
		Choices: []Choice{
			{
				Index:   0,
				Message: &Message{Role: "assistant", Content: "Hello!"},
			},
		},
		Usage: Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
	}

	body, err := json.Marshal(resp)
	require.NoError(t, err)

	httpResp := &http.Response{
		StatusCode: 200,
		Body:       io.NopCloser(bytes.NewReader(body)),
	}

	result, err := adapter.ParseResponse(httpResp)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", result.ID)
	assert.Equal(t, "gpt-4", result.Model)
	require.Len(t, result.Choices, 1)
	assert.Equal(t, "Hello!", result.Choices[0].Message.Content)
}
|
||||
|
||||
// TestAdapter_ParseErrorResponse verifies that a non-2xx error envelope is
// decoded into an ErrorResponse with message and type preserved.
func TestAdapter_ParseErrorResponse(t *testing.T) {
	adapter := NewAdapter()
	errResp := &ErrorResponse{
		Error: ErrorDetail{
			Message: "Invalid API key",
			Type:    "invalid_request_error",
			Code:    "invalid_api_key",
		},
	}

	body, err := json.Marshal(errResp)
	require.NoError(t, err)

	httpResp := &http.Response{
		StatusCode: 401,
		Body:       io.NopCloser(bytes.NewReader(body)),
	}

	result, err := adapter.ParseErrorResponse(httpResp)
	require.NoError(t, err)
	assert.Equal(t, "Invalid API key", result.Error.Message)
	assert.Equal(t, "invalid_request_error", result.Error.Type)
}

// TestAdapter_ParseStreamChunk verifies decoding of a single SSE data
// payload into a StreamChunk, including the delta content.
func TestAdapter_ParseStreamChunk(t *testing.T) {
	adapter := NewAdapter()
	chunk := &StreamChunk{
		ID:      "chatcmpl-123",
		Object:  "chat.completion.chunk",
		Created: 1234567890,
		Model:   "gpt-4",
		Choices: []StreamChoice{
			{
				Index: 0,
				Delta: Delta{Content: "Hello"},
			},
		},
	}

	data, err := json.Marshal(chunk)
	require.NoError(t, err)

	result, err := adapter.ParseStreamChunk(data)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", result.ID)
	require.Len(t, result.Choices, 1)
	assert.Equal(t, "Hello", result.Choices[0].Delta.Content)
}
|
||||
|
||||
// TestParseToolCallArguments checks that the arguments JSON string is
// decoded into a map for valid JSON and rejected otherwise.
func TestParseToolCallArguments(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		wantErr bool
	}{
		{"有效JSON", `{"key": "value"}`, false},
		{"无效JSON", `not json`, true},
		{"空JSON", `{}`, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tc := &ToolCall{
				Function: FunctionCall{Arguments: tt.input},
			}
			args, err := tc.ParseToolCallArguments()
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, args)
			}
		})
	}
}

// TestSerializeToolCallArguments checks map-to-JSON serialization of
// tool-call arguments.
func TestSerializeToolCallArguments(t *testing.T) {
	args := map[string]interface{}{"key": "value"}
	result, err := SerializeToolCallArguments(args)
	require.NoError(t, err)
	assert.JSONEq(t, `{"key": "value"}`, result)
}

// TestValidateRequest covers the happy path plus each required-field
// failure of ChatCompletionRequest validation.
func TestValidateRequest(t *testing.T) {
	t.Run("有效请求", func(t *testing.T) {
		req := &ChatCompletionRequest{
			Model:    "gpt-4",
			Messages: []Message{{Role: "user", Content: "hello"}},
		}
		errs := ValidateRequest(req)
		assert.Nil(t, errs)
	})

	t.Run("缺少模型", func(t *testing.T) {
		req := &ChatCompletionRequest{
			Messages: []Message{{Role: "user", Content: "hello"}},
		}
		errs := ValidateRequest(req)
		assert.NotNil(t, errs)
		assert.Contains(t, errs["model"], "不能为空")
	})

	t.Run("缺少消息", func(t *testing.T) {
		req := &ChatCompletionRequest{
			Model: "gpt-4",
		}
		errs := ValidateRequest(req)
		assert.NotNil(t, errs)
		assert.Contains(t, errs["messages"], "不能为空")
	})

	t.Run("空消息列表", func(t *testing.T) {
		req := &ChatCompletionRequest{
			Model:    "gpt-4",
			Messages: []Message{},
		}
		errs := ValidateRequest(req)
		assert.NotNil(t, errs)
	})
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
pkgValidator "nex/backend/pkg/validator"
|
||||
)
|
||||
|
||||
// ChatCompletionRequest is the request body of the OpenAI Chat Completions API.
type ChatCompletionRequest struct {
	Model            string      `json:"model" validate:"required"`
	Messages         []Message   `json:"messages" validate:"required,min=1"`
	Temperature      *float64    `json:"temperature,omitempty"`
	MaxTokens        *int        `json:"max_tokens,omitempty"`
	TopP             *float64    `json:"top_p,omitempty"`
	FrequencyPenalty *float64    `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64    `json:"presence_penalty,omitempty"`
	Stop             interface{} `json:"stop,omitempty"` // either a string or an array of strings
	N                *int        `json:"n,omitempty"`
	Stream           bool        `json:"stream,omitempty"`
	Tools            []Tool      `json:"tools,omitempty"`
	ToolChoice       interface{} `json:"tool_choice,omitempty"` // either a string or an object
	User             string      `json:"user,omitempty"`
}

// Message is a single conversation turn in the OpenAI format.
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"` // either a string or an array (multimodal; not supported in the MVP)
	Name       string      `json:"name,omitempty"`
	ToolCalls  []ToolCall  `json:"tool_calls,omitempty"`
	ToolCallID string      `json:"tool_call_id,omitempty"` // set on messages with role "tool"
}

// Tool is a tool definition in the OpenAI format.
type Tool struct {
	Type     string             `json:"type"` // currently only "function"
	Function FunctionDefinition `json:"function"`
}

// FunctionDefinition describes a callable function exposed to the model.
type FunctionDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}

// ToolCall is a tool invocation emitted by the model.
type ToolCall struct {
	ID       string       `json:"id"`
	Type     string       `json:"type"` // "function"
	Function FunctionCall `json:"function"`
}

// FunctionCall names the function to invoke and carries its arguments.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON-encoded string
}

// ChatCompletionResponse is the response body of the OpenAI Chat Completions API.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// Choice is one completion alternative in a response.
type Choice struct {
	Index        int      `json:"index"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Delta   `json:"delta,omitempty"` // used in streaming responses
	FinishReason string   `json:"finish_reason"`
}

// Delta carries incremental streamed content.
type Delta struct {
	Role      string     `json:"role,omitempty"`
	Content   string     `json:"content,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Usage reports token consumption for a request/response pair.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// StreamChunk is one SSE payload in a streaming response.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// StreamChoice is one completion alternative within a stream chunk.
type StreamChoice struct {
	Index        int    `json:"index"`
	Delta        Delta  `json:"delta"`
	FinishReason string `json:"finish_reason,omitempty"`
}

// ErrorResponse is the OpenAI error envelope.
type ErrorResponse struct {
	Error ErrorDetail `json:"error"`
}

// ErrorDetail describes a single OpenAI API error.
type ErrorDetail struct {
	Message string `json:"message"`
	Type    string `json:"type,omitempty"`
	Code    string `json:"code,omitempty"`
}
|
||||
|
||||
// ParseToolCallArguments 解析 tool_call 的 arguments(从 JSON 字符串转为 map)
|
||||
func (tc *ToolCall) ParseToolCallArguments() (map[string]interface{}, error) {
|
||||
var args map[string]interface{}
|
||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
return args, err
|
||||
}
|
||||
|
||||
// SerializeToolCallArguments encodes an arguments map as a JSON string
// (the inverse of ParseToolCallArguments).
func SerializeToolCallArguments(args map[string]interface{}) (string, error) {
	encoded, err := json.Marshal(args)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
|
||||
|
||||
// ValidateRequest 验证 ChatCompletionRequest
|
||||
func ValidateRequest(req *ChatCompletionRequest) map[string]string {
|
||||
errs := pkgValidator.Validate(req)
|
||||
if errs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
validationErrors := make(map[string]string)
|
||||
for _, err := range errs.(validator.ValidationErrors) {
|
||||
field := err.Field()
|
||||
switch field {
|
||||
case "Model":
|
||||
validationErrors["model"] = "模型名称不能为空"
|
||||
case "Messages":
|
||||
validationErrors["messages"] = "消息列表不能为空"
|
||||
default:
|
||||
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
|
||||
}
|
||||
}
|
||||
return validationErrors
|
||||
}
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/conversion"
|
||||
)
|
||||
|
||||
// StreamConfig 流式处理配置
|
||||
type StreamConfig struct {
|
||||
InitialBufferSize int // 初始缓冲区大小(字节),默认 4096
|
||||
MaxBufferSize int // 最大缓冲区大小(字节),默认 65536
|
||||
Timeout time.Duration // 流超时时间,默认 5 分钟
|
||||
ChannelBufferSize int // 事件通道缓冲区大小,默认 100
|
||||
InitialBufferSize int
|
||||
MaxBufferSize int
|
||||
Timeout time.Duration
|
||||
ChannelBufferSize int
|
||||
}
|
||||
|
||||
// DefaultStreamConfig 返回默认流式处理配置
|
||||
@@ -32,14 +32,6 @@ func DefaultStreamConfig() StreamConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// Client OpenAI 兼容供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
adapter *openai.Adapter
|
||||
logger *zap.Logger
|
||||
streamCfg StreamConfig
|
||||
}
|
||||
|
||||
// StreamEvent 流事件
|
||||
type StreamEvent struct {
|
||||
Data []byte
|
||||
@@ -47,10 +39,17 @@ type StreamEvent struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
// Client 协议无关的供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
logger *zap.Logger
|
||||
streamCfg StreamConfig
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
type ProviderClient interface {
|
||||
SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error)
|
||||
SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error)
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
@@ -59,97 +58,98 @@ func NewClient() *Client {
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
adapter: openai.NewAdapter(),
|
||||
logger: zap.L(),
|
||||
streamCfg: DefaultStreamConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// SendRequest 发送非流式请求
|
||||
func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
// 准备请求
|
||||
httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
|
||||
// Send 发送非流式请求
|
||||
func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
var bodyReader io.Reader
|
||||
if len(spec.Body) > 0 {
|
||||
bodyReader = bytes.NewReader(spec.Body)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range spec.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
c.logger.Debug("发送请求",
|
||||
zap.String("url", httpReq.URL.String()),
|
||||
zap.String("method", httpReq.Method),
|
||||
zap.String("url", spec.URL),
|
||||
zap.String("method", spec.Method),
|
||||
)
|
||||
|
||||
// 设置上下文
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 解析错误响应
|
||||
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
result, err := c.adapter.ParseResponse(resp)
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
respHeaders := make(map[string]string)
|
||||
for k, vs := range resp.Header {
|
||||
if len(vs) > 0 {
|
||||
respHeaders[k] = vs[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Body: respBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendStreamRequest 发送流式请求
|
||||
func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) {
|
||||
// 确保请求设置为流式
|
||||
req.Stream = true
|
||||
|
||||
// 准备请求
|
||||
httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
// SendStream 发送流式请求
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
|
||||
var bodyReader io.Reader
|
||||
if len(spec.Body) > 0 {
|
||||
bodyReader = bytes.NewReader(spec.Body)
|
||||
}
|
||||
|
||||
// 设置带超时的上下文
|
||||
streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout)
|
||||
_ = cancel // cancel 在流读取结束后由 ctx 传播处理
|
||||
httpReq = httpReq.WithContext(streamCtx)
|
||||
httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range spec.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||
|
||||
// 启动 goroutine 读取流
|
||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流(支持动态缓冲区、超时控制和改进的错误处理)
|
||||
// readStream 读取 SSE 流
|
||||
func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) {
|
||||
defer close(eventChan)
|
||||
defer body.Close()
|
||||
@@ -175,10 +175,8 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
n, err := body.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// 流正常结束
|
||||
return
|
||||
}
|
||||
// 区分网络错误和其他错误
|
||||
if isNetworkError(err) {
|
||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||
@@ -191,7 +189,6 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
|
||||
dataBuf = append(dataBuf, buf[:n]...)
|
||||
|
||||
// 动态调整缓冲区大小:如果数据量大,增大缓冲区
|
||||
if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize {
|
||||
newSize := bufSize * 2
|
||||
if newSize > c.streamCfg.MaxBufferSize {
|
||||
@@ -201,34 +198,21 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
bufSize = newSize
|
||||
}
|
||||
|
||||
// 处理完整的 SSE 事件
|
||||
for {
|
||||
// 查找事件边界(双换行)
|
||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// 提取事件
|
||||
event := dataBuf[:idx]
|
||||
rawEvent := dataBuf[:idx]
|
||||
dataBuf = dataBuf[idx+2:]
|
||||
|
||||
// 解析 data 行
|
||||
lines := strings.Split(string(event), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送数据
|
||||
eventChan <- StreamEvent{Data: []byte(data)}
|
||||
}
|
||||
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
eventChan <- StreamEvent{Data: rawEvent}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,4 +229,3 @@ func isNetworkError(err error) bool {
|
||||
strings.Contains(errStr, "timeout") ||
|
||||
strings.Contains(errStr, "EOF")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,14 +10,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/conversion"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient()
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.httpClient)
|
||||
assert.NotNil(t, client.adapter)
|
||||
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
|
||||
assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
|
||||
assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
|
||||
@@ -31,67 +29,66 @@ func TestDefaultStreamConfig(t *testing.T) {
|
||||
assert.Equal(t, 100, cfg.ChannelBufferSize)
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_Success(t *testing.T) {
|
||||
func TestClient_Send_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
|
||||
resp := openai.ChatCompletionResponse{
|
||||
ID: "chatcmpl-123",
|
||||
Choices: []openai.Choice{
|
||||
{Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test-key",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Body: []byte(`{"model":"gpt-4","messages":[]}`),
|
||||
}
|
||||
|
||||
result, err := client.SendRequest(context.Background(), req, "test-key", server.URL)
|
||||
result, err := client.Send(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Contains(t, string(result.Body), "test")
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_ErrorResponse(t *testing.T) {
|
||||
func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{Message: "Invalid API key"},
|
||||
})
|
||||
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: map[string]string{"Authorization": "Bearer bad-key"},
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
_, err := client.SendRequest(context.Background(), req, "bad-key", server.URL)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid API key")
|
||||
result, err := client.Send(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 401, result.StatusCode)
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_ConnectionError(t *testing.T) {
|
||||
func TestClient_Send_ConnectionError(t *testing.T) {
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: "http://localhost:1/v1/chat/completions",
|
||||
Method: "POST",
|
||||
}
|
||||
|
||||
_, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1")
|
||||
_, err := client.Send(context.Background(), spec)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
|
||||
// 使用一个慢服务器确保客户端有时间读取
|
||||
func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -99,35 +96,36 @@ func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: map[string]string{"Authorization": "Bearer test-key"},
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL)
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eventChan)
|
||||
|
||||
// 读取直到 channel 关闭(服务器关闭后应产生 EOF)
|
||||
for range eventChan {
|
||||
// 消费所有事件
|
||||
}
|
||||
// channel 应已关闭(不阻塞即通过)
|
||||
}
|
||||
|
||||
func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) {
|
||||
func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Headers: map[string]string{"Authorization": "Bearer key"},
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
_, err := client.SendStreamRequest(context.Background(), req, "key", server.URL)
|
||||
_, err := client.SendStream(context.Background(), spec)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -145,7 +143,7 @@ func TestIsNetworkError(t *testing.T) {
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
err := fmt.Errorf("%s", tt.input) //nolint:govet
|
||||
err := fmt.Errorf("%s", tt.input)
|
||||
assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ func toDomainProvider(p *config.Provider) domain.Provider {
|
||||
Name: p.Name,
|
||||
APIKey: p.APIKey,
|
||||
BaseURL: p.BaseURL,
|
||||
Protocol: p.Protocol,
|
||||
Enabled: p.Enabled,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
@@ -85,10 +86,11 @@ func toDomainProvider(p *config.Provider) domain.Provider {
|
||||
|
||||
func toConfigProvider(p *domain.Provider) config.Provider {
|
||||
return config.Provider{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
APIKey: p.APIKey,
|
||||
BaseURL: p.BaseURL,
|
||||
Enabled: p.Enabled,
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
APIKey: p.APIKey,
|
||||
BaseURL: p.BaseURL,
|
||||
Protocol: p.Protocol,
|
||||
Enabled: p.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
-- +goose Up
|
||||
ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai';
|
||||
|
||||
-- +goose Down
|
||||
-- SQLite 不支持 DROP COLUMN(3.35.0 之前),但 goose 的 Down 通常不需要
|
||||
CREATE TABLE providers_backup AS SELECT id, name, api_key, base_url, enabled, created_at, updated_at FROM providers;
|
||||
571
backend/tests/integration/conversion_test.go
Normal file
571
backend/tests/integration/conversion_test.go
Normal file
@@ -0,0 +1,571 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
openaiConv "nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// init switches Gin into test mode for the whole package so the router runs
// without debug logging noise while these integration tests execute.
func init() {
	gin.SetMode(gin.TestMode)
}
|
||||
|
||||
// setupConversionTest 创建包含 ConversionEngine 的完整测试环境
|
||||
func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) {
|
||||
t.Helper()
|
||||
|
||||
// 创建 mock 上游服务器
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 默认返回成功,由各测试 case 覆盖
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"error":"not mocked"}`))
|
||||
}))
|
||||
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
upstream.Close()
|
||||
})
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
|
||||
// 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
require.NoError(t, registry.Register(openaiConv.NewAdapter()))
|
||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||
engine := conversion.NewConversionEngine(registry)
|
||||
|
||||
providerClient := provider.NewClient()
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
_ = modelService
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 代理路由
|
||||
r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy)
|
||||
|
||||
// 管理路由
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
_ = statsHandler
|
||||
|
||||
return r, db, upstream
|
||||
}
|
||||
|
||||
// createProviderAndModel 辅助:创建供应商和模型
|
||||
func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, modelName string, upstreamURL string) {
|
||||
t.Helper()
|
||||
|
||||
providerBody, _ := json.Marshal(map[string]string{
|
||||
"id": providerID,
|
||||
"name": providerID,
|
||||
"api_key": "test-key",
|
||||
"base_url": upstreamURL,
|
||||
"protocol": protocol,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
modelBody, _ := json.Marshal(map[string]string{
|
||||
"id": modelName,
|
||||
"provider_id": providerID,
|
||||
"model_name": modelName,
|
||||
})
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, 201, w.Code)
|
||||
}
|
||||
|
||||
// ============ 跨协议非流式转换测试 ============
|
||||
|
||||
func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
// 配置上游返回 Anthropic 格式响应
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 验证请求被转换为 Anthropic 格式
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||
assert.Contains(t, r.Header.Get("Content-Type"), "application/json")
|
||||
|
||||
// 返回 Anthropic 响应
|
||||
resp := map[string]any{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3-opus",
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": "Hello from Anthropic!"},
|
||||
},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
||||
|
||||
// 使用 OpenAI 格式发送请求
|
||||
openaiReq := map[string]any{
|
||||
"model": "claude-3-opus",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "Hello"},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
body, _ := json.Marshal(openaiReq)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "chat.completion", resp["object"])
|
||||
|
||||
choices := resp["choices"].([]any)
|
||||
require.Len(t, choices, 1)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
assert.Contains(t, msg["content"], "Hello from Anthropic!")
|
||||
}
|
||||
|
||||
func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
|
||||
assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key")
|
||||
|
||||
resp := map[string]any{
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": "gpt-4",
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "Hello from OpenAI!"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
||||
|
||||
anthropicReq := map[string]any{
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 1024,
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "Hello"},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
body, _ := json.Marshal(anthropicReq)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "message", resp["type"])
|
||||
|
||||
content := resp["content"].([]any)
|
||||
require.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
assert.Contains(t, block["text"], "Hello from OpenAI!")
|
||||
}
|
||||
|
||||
// ============ 同协议透传测试 ============
|
||||
|
||||
func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req map[string]any
|
||||
json.Unmarshal(body, &req)
|
||||
assert.Equal(t, "gpt-4", req["model"])
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-4",
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "passthrough")
|
||||
}
|
||||
|
||||
func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "passthrough")
|
||||
}
|
||||
|
||||
// ============ 流式转换测试 ============
|
||||
|
||||
func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
events := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"model\":\"claude-3-opus\",\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
||||
|
||||
openaiReq := map[string]any{
|
||||
"model": "claude-3-opus",
|
||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||
"stream": true,
|
||||
}
|
||||
body, _ := json.Marshal(openaiReq)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
ct := w.Header().Get("Content-Type")
|
||||
assert.Contains(t, ct, "text/event-stream")
|
||||
}
|
||||
|
||||
func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
|
||||
r, _, upstream := setupConversionTest(t)
|
||||
|
||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
events := []string{
|
||||
fmt.Sprintf("data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"}}]}\n\n"),
|
||||
"data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hey\"}}]}\n\n",
|
||||
"data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n",
|
||||
"data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
|
||||
"data: [DONE]\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
||||
|
||||
anthropicReq := map[string]any{
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 1024,
|
||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||
"stream": true,
|
||||
}
|
||||
body, _ := json.Marshal(anthropicReq)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
ct := w.Header().Get("Content-Type")
|
||||
assert.Contains(t, ct, "text/event-stream")
|
||||
}
|
||||
|
||||
// ============ Models 接口测试 ============
|
||||
|
||||
func TestConversion_Models_CrossProtocol(t *testing.T) {
|
||||
// 测试 Models 接口跨协议转换的编解码逻辑
|
||||
// 由于 GET /models 无 body 无法路由,此处测试 adapter 级别的编解码
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
require.NoError(t, registry.Register(openaiConv.NewAdapter()))
|
||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||
|
||||
openaiAdapter, _ := registry.Get("openai")
|
||||
anthropicAdapter, _ := registry.Get("anthropic")
|
||||
|
||||
// 模拟 OpenAI 格式的 models 响应
|
||||
openaiModelsBody := []byte(`{"object":"list","data":[{"id":"gpt-4","object":"model","created":1700000000,"owned_by":"openai"},{"id":"gpt-3.5-turbo","object":"model","created":1700000001,"owned_by":"openai"}]}`)
|
||||
|
||||
// OpenAI decode → Canonical → Anthropic encode
|
||||
modelList, err := openaiAdapter.DecodeModelsResponse(openaiModelsBody)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, modelList.Models, 2)
|
||||
assert.Equal(t, "gpt-4", modelList.Models[0].ID)
|
||||
|
||||
// 编码为 Anthropic 格式
|
||||
anthropicBody, err := anthropicAdapter.EncodeModelsResponse(modelList)
|
||||
require.NoError(t, err)
|
||||
|
||||
var anthropicResp map[string]any
|
||||
json.Unmarshal(anthropicBody, &anthropicResp)
|
||||
data := anthropicResp["data"].([]any)
|
||||
assert.Len(t, data, 2)
|
||||
|
||||
first := data[0].(map[string]any)
|
||||
assert.Equal(t, "gpt-4", first["id"])
|
||||
assert.Equal(t, "model", first["type"])
|
||||
|
||||
// 反向测试:Anthropic decode → Canonical → OpenAI encode
|
||||
anthropicModelsBody := []byte(`{"data":[{"id":"claude-3-opus","type":"model","display_name":"Claude 3 Opus","created_at":"2025-01-01T00:00:00Z"}],"has_more":false}`)
|
||||
modelList2, err := anthropicAdapter.DecodeModelsResponse(anthropicModelsBody)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, modelList2.Models, 1)
|
||||
assert.Equal(t, "Claude 3 Opus", modelList2.Models[0].Name)
|
||||
|
||||
openaiBody, err := openaiAdapter.EncodeModelsResponse(modelList2)
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiResp map[string]any
|
||||
json.Unmarshal(openaiBody, &err)
|
||||
json.Unmarshal(openaiBody, &openaiResp)
|
||||
oaiData := openaiResp["data"].([]any)
|
||||
assert.Len(t, oaiData, 1)
|
||||
firstOai := oaiData[0].(map[string]any)
|
||||
assert.Equal(t, "claude-3-opus", firstOai["id"])
|
||||
}
|
||||
|
||||
// ============ 错误响应测试 ============
|
||||
|
||||
func TestConversion_ErrorResponse_Format(t *testing.T) {
|
||||
r, _, _ := setupConversionTest(t)
|
||||
|
||||
// 请求不存在的模型
|
||||
reqBody := map[string]any{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
// OpenAI 协议格式
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.True(t, w.Code >= 400)
|
||||
}
|
||||
|
||||
// ============ 旧路由返回 404 ============
|
||||
|
||||
func TestConversion_OldRoutes_Return404(t *testing.T) {
|
||||
r, _, _ := setupConversionTest(t)
|
||||
|
||||
// 旧 OpenAI 路由
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader(`{"model":"test"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
// Gin 路由不匹配返回 404
|
||||
assert.Equal(t, 404, w.Code)
|
||||
|
||||
// 旧 Anthropic 路由
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/v1/messages", strings.NewReader(`{"model":"test"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
// ============ Provider Protocol 字段测试 ============
|
||||
|
||||
func TestConversion_ProviderWithProtocol(t *testing.T) {
|
||||
r, _, _ := setupConversionTest(t)
|
||||
|
||||
// 创建带 protocol 字段的 provider
|
||||
providerBody := map[string]any{
|
||||
"id": "test-protocol",
|
||||
"name": "Test Protocol",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://test.com",
|
||||
"protocol": "anthropic",
|
||||
}
|
||||
body, _ := json.Marshal(providerBody)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
var created map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &created)
|
||||
// API Key 被掩码
|
||||
assert.Contains(t, created["api_key"], "***")
|
||||
|
||||
// 获取时应包含 protocol
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/providers/test-protocol", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var fetched map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &fetched)
|
||||
assert.Equal(t, "anthropic", fetched["protocol"])
|
||||
}
|
||||
|
||||
func TestConversion_ProviderDefaultProtocol(t *testing.T) {
|
||||
r, _, _ := setupConversionTest(t)
|
||||
|
||||
// 不指定 protocol,默认应为 openai
|
||||
providerBody := map[string]any{
|
||||
"id": "default-proto",
|
||||
"name": "Default",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://test.com",
|
||||
}
|
||||
body, _ := json.Marshal(providerBody)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, 201, w.Code)
|
||||
|
||||
var created map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &created)
|
||||
assert.Equal(t, "openai", created["protocol"])
|
||||
}
|
||||
|
||||
// Blank assignments intended to keep fmt, strings and time imported.
// NOTE(review): fmt.Sprintf, strings.NewReader and time.Now appear to be
// used directly elsewhere in this file, so these suppressors look
// redundant — confirm and remove if the imports are genuinely referenced.
var _ = fmt.Sprintf
var _ = strings.Contains
var _ = time.Second
|
||||
@@ -180,14 +180,14 @@ type ProviderClient interface {
|
||||
- 路由时需要知道 providerProtocol 以选择正确的 Adapter
|
||||
- 默认值 `'openai'` 确保现有数据兼容
|
||||
|
||||
### D7: 删除旧 `internal/protocol/` 包,在 `internal/conversion/` 中全新实现

**选择**:直接删除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,在 `internal/conversion/` 下对照设计文档全新编写所有代码
|
||||
|
||||
**理由**:
|
||||
- 旧代码的设计模式(OpenAI 类型为枢纽)与新架构根本不同
|
||||
- 旧代码的设计模式(OpenAI 类型为枢纽)与新架构根本不同,无法复用
|
||||
- 保留旧代码容易导致混用两种模式,引入隐蔽 bug
|
||||
- 旧代码中的类型定义可以迁移(copy-paste),但组织方式需重建
|
||||
- 旧代码中的类型定义不迁移,直接根据设计文档重新定义,确保与新架构一致
|
||||
|
||||
### D8: 目标包结构
|
||||
|
||||
@@ -206,7 +206,7 @@ internal/conversion/
|
||||
engine.go # ConversionEngine 门面 + HTTPRequestSpec/HTTPResponseSpec
|
||||
|
||||
openai/
|
||||
types.go # OpenAI 线路格式类型(从旧 protocol/openai/types.go 迁移并补全)
|
||||
types.go # OpenAI 线路格式类型(对照 conversion_openai.md 全新定义)
|
||||
adapter.go # ProtocolAdapter 实现(detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError)
|
||||
decoder.go # decodeRequest/decodeResponse/扩展层 decode 方法
|
||||
encoder.go # encodeRequest/encodeResponse/扩展层 encode 方法
|
||||
@@ -214,7 +214,7 @@ internal/conversion/
|
||||
stream_encoder.go # OpenAIStreamEncoder(缓冲策略)
|
||||
|
||||
anthropic/
|
||||
types.go # Anthropic 线路格式类型(从旧 protocol/anthropic/types.go 迁移并补全)
|
||||
types.go # Anthropic 线路格式类型(对照 conversion_anthropic.md 全新定义)
|
||||
adapter.go # ProtocolAdapter 实现(detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError)
|
||||
decoder.go # decodeRequest/decodeResponse/扩展层 decode 方法
|
||||
encoder.go # encodeRequest/encodeResponse/扩展层 encode 方法
|
||||
@@ -260,14 +260,14 @@ internal/conversion/
|
||||
### 步骤
|
||||
|
||||
1. **创建 `internal/conversion/` 包**:实现 Layer 1-3(Canonical Model、接口定义、Engine),不改动现有代码
|
||||
2. **实现 OpenAI Adapter 和 Anthropic Adapter**:Layer 4-5,在 conversion 包内自包含
|
||||
2. **全新实现 OpenAI Adapter 和 Anthropic Adapter**:Layer 4-5,对照设计文档在 conversion 包内全新编写,不沿用旧 protocol 包代码
|
||||
3. **编写全面测试**:覆盖编解码、流式转换、错误处理、同协议透传
|
||||
4. **改造 `domain.Provider`**:新增 `Protocol` 字段
|
||||
5. **创建数据库迁移**:`ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'`
|
||||
6. **改造 `ProviderClient`**:简化为接受 `HTTPRequestSpec` 的 HTTP 发送器
|
||||
7. **创建 `ProxyHandler`**:统一代理入口,集成 ConversionEngine
|
||||
8. **更新 `cmd/server/main.go`**:注册 Adapter、创建 Engine、配置新路由
|
||||
9. **删除旧 `internal/protocol/` 包**:确认新架构完全替代后删除
|
||||
9. **删除旧 `internal/protocol/` 包**:直接删除,不迁移代码,确认新架构完全替代
|
||||
10. **更新 README.md**:项目结构、API 接口、路由说明
|
||||
|
||||
### 兼容策略
|
||||
@@ -279,7 +279,7 @@ internal/conversion/
|
||||
### 回滚策略
|
||||
|
||||
- Git 分支隔离:在新分支开发,合并前充分测试
|
||||
- 旧 `internal/protocol/` 包在确认新架构稳定后再删除
|
||||
- 旧 `internal/protocol/` 包在删除前确认新架构所有测试通过,删除后不可恢复旧代码(从 git 历史仍可找回)
|
||||
- 数据库迁移向下兼容(仅 ADD COLUMN)
|
||||
|
||||
## Open Questions
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
- **引入 Canonical Model**:定义协议无关的 `CanonicalRequest`、`CanonicalResponse`、`CanonicalStreamEvent` 等规范模型,作为所有协议间转换的统一枢纽
|
||||
- **引入 ConversionEngine**:无状态的转换引擎门面,协调 Adapter 注册、接口识别、透传判断、请求/响应转换、流式转换
|
||||
- **引入 ProtocolAdapter 接口**:统一适配器契约,每种协议实现完整的编解码(Chat 请求/响应、流式、扩展层接口、错误编码)
|
||||
- **实现 OpenAI Adapter**:对照 `docs/conversion_openai.md` 实现 OpenAI 协议的完整 Adapter(含状态机流式解码器/编码器)
|
||||
- **实现 Anthropic Adapter**:对照 `docs/conversion_anthropic.md` 实现 Anthropic 协议的完整 Adapter(含命名事件流式解码器/编码器)
|
||||
- **实现 OpenAI Adapter**:对照 `docs/conversion_openai.md` 全新实现 OpenAI 协议的完整 Adapter(含状态机流式解码器/编码器),不沿用旧 `internal/protocol/openai/` 代码
|
||||
- **实现 Anthropic Adapter**:对照 `docs/conversion_anthropic.md` 全新实现 Anthropic 协议的完整 Adapter(含命名事件流式解码器/编码器),不沿用旧 `internal/protocol/anthropic/` 代码
|
||||
- **统一代理 Handler**:合并 `OpenAIHandler` 和 `AnthropicHandler` 为统一的 `ProxyHandler`,支持 `/{protocol}/v1/...` URL 前缀路由
|
||||
- **同协议透传**:client == provider 时跳过 Canonical 转换,仅重建 Header 后原样转发
|
||||
- **接口分层**:核心层(Chat)走 Canonical 深度转换,扩展层(Models/Embeddings/Rerank)走轻量映射,未知接口走透传
|
||||
- **ProviderClient 简化**:移除 OpenAI Adapter 硬编码,变为协议无关的 HTTP 发送器
|
||||
- **Provider 新增 Protocol 字段**:**BREAKING** — Provider 模型新增 `protocol` 字段标识上游协议类型
|
||||
- **删除旧 protocol 包**:移除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,全部逻辑迁入 `internal/conversion/`
|
||||
- **删除旧 protocol 包**:移除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,在 `internal/conversion/` 中全新实现
|
||||
- **URL 路由变更**:**BREAKING** — 代理端点从 `/v1/chat/completions` + `/v1/messages` 变更为 `/{protocol}/v1/...`,不保留旧路由
|
||||
|
||||
## Capabilities
|
||||
@@ -37,7 +37,7 @@
|
||||
|
||||
## Impact
|
||||
|
||||
- **代码结构**:新增 `internal/conversion/` 包(约 20+ 文件),删除 `internal/protocol/` 包,改造 `internal/handler/` 和 `internal/provider/`
|
||||
- **代码结构**:新增 `internal/conversion/` 包(约 20+ 文件,全新编写),删除 `internal/protocol/` 包(不迁移,直接删除后重写),改造 `internal/handler/` 和 `internal/provider/`
|
||||
- **API 兼容性**:**BREAKING** — 代理端点 URL 变更(`/v1/chat/completions` → `/openai/v1/chat/completions`,`/v1/messages` → `/anthropic/v1/messages`),不保留旧路由
|
||||
- **数据库**:Provider 表新增 `protocol` 列,需数据库迁移
|
||||
- **依赖**:无新增外部依赖,复用现有 Go 标准库和已引入的包
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
### Requirement: 实现 Anthropic ProtocolAdapter
|
||||
|
||||
系统 SHALL 实现 Anthropic 协议的完整 ProtocolAdapter,对照 `docs/conversion_anthropic.md`。
|
||||
系统 SHALL 全新实现 Anthropic 协议的完整 ProtocolAdapter,对照 `docs/conversion_anthropic.md`。不沿用旧 `internal/protocol/anthropic/` 代码。
|
||||
|
||||
- `protocolName()` SHALL 返回 `"anthropic"`
|
||||
- `supportsPassthrough()` SHALL 返回 true
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
### Requirement: 实现 OpenAI ProtocolAdapter
|
||||
|
||||
系统 SHALL 实现 OpenAI 协议的完整 ProtocolAdapter,对照 `docs/conversion_openai.md`。
|
||||
系统 SHALL 全新实现 OpenAI 协议的完整 ProtocolAdapter,对照 `docs/conversion_openai.md`。不沿用旧 `internal/protocol/openai/` 代码。
|
||||
|
||||
- `protocolName()` SHALL 返回 `"openai"`
|
||||
- `supportsPassthrough()` SHALL 返回 true
|
||||
|
||||
@@ -1,49 +1,49 @@
|
||||
## 1. 基础类型层 — Canonical Model 和核心类型定义
|
||||
|
||||
- [ ] 1.1 创建 `internal/conversion/errors.go`:定义 ConversionError 结构体(Code, Message, ClientProtocol, ProviderProtocol, InterfaceType, Details, Cause)和 ErrorCode 枚举(INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED),实现 error 接口
|
||||
- [ ] 1.2 创建 `internal/conversion/interface.go`:定义 InterfaceType 枚举(CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK)
|
||||
- [ ] 1.3 创建 `internal/conversion/provider.go`:定义 TargetProvider 结构体(BaseURL, APIKey, ModelName, AdapterConfig map[string]any);编写测试
|
||||
- [ ] 1.4 创建 `internal/conversion/canonical/types.go`:定义 CanonicalRequest(model, system, messages, tools, tool_choice, parameters, thinking, stream, user_id, output_format, parallel_tool_use)、CanonicalMessage(role 枚举: system/user/assistant/tool, content []ContentBlock)、ContentBlock(使用 type 字段的 discriminated union:text/tool_use/tool_result/thinking,ToolInput 使用 json.RawMessage)、CanonicalTool(name, description, input_schema)、ToolChoice 联合体(auto/none/any/tool+name)、RequestParameters(max_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty, stop_sequences)、ThinkingConfig(type: enabled/disabled/adaptive, budget_tokens, effort)、OutputFormat(json_object/json_schema+schema/text)、CanonicalResponse(id, model, content, stop_reason 枚举, usage)、CanonicalUsage(input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, reasoning_tokens)、SystemBlock(text);编写构造和序列化测试
|
||||
- [ ] 1.5 创建 `internal/conversion/canonical/stream.go`:定义 CanonicalStreamEvent 联合体(message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop, error, ping)及各事件的具体结构(MessageStartEvent 含 message{id,model,usage}、ContentBlockStartEvent 含 index 和 content_block、ContentBlockDeltaEvent 含 index 和 delta、ContentBlockStopEvent 含 index、MessageDeltaEvent 含 delta{stop_reason} 和 usage、MessageStopEvent、ErrorEvent、PingEvent),delta 联合体(text_delta, input_json_delta, thinking_delta),content_block 联合体(text, tool_use, thinking);编写测试
|
||||
- [ ] 1.6 创建 `internal/conversion/canonical/extended.go`:定义扩展层 Canonical Models(CanonicalModelList, CanonicalModel, CanonicalModelInfo, CanonicalEmbeddingRequest, CanonicalEmbeddingResponse, CanonicalRerankRequest, CanonicalRerankResponse);编写测试
|
||||
- [x] 1.1 创建 `internal/conversion/errors.go`:定义 ConversionError 结构体(Code, Message, ClientProtocol, ProviderProtocol, InterfaceType, Details, Cause)和 ErrorCode 枚举(INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED),实现 error 接口
|
||||
- [x] 1.2 创建 `internal/conversion/interface.go`:定义 InterfaceType 枚举(CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK)
|
||||
- [x] 1.3 创建 `internal/conversion/provider.go`:定义 TargetProvider 结构体(BaseURL, APIKey, ModelName, AdapterConfig map[string]any);编写测试
|
||||
- [x] 1.4 创建 `internal/conversion/canonical/types.go`:定义 CanonicalRequest(model, system, messages, tools, tool_choice, parameters, thinking, stream, user_id, output_format, parallel_tool_use)、CanonicalMessage(role 枚举: system/user/assistant/tool, content []ContentBlock)、ContentBlock(使用 type 字段的 discriminated union:text/tool_use/tool_result/thinking,ToolInput 使用 json.RawMessage)、CanonicalTool(name, description, input_schema)、ToolChoice 联合体(auto/none/any/tool+name)、RequestParameters(max_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty, stop_sequences)、ThinkingConfig(type: enabled/disabled/adaptive, budget_tokens, effort)、OutputFormat(json_object/json_schema+schema/text)、CanonicalResponse(id, model, content, stop_reason 枚举, usage)、CanonicalUsage(input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, reasoning_tokens)、SystemBlock(text);编写构造和序列化测试
|
||||
- [x] 1.5 创建 `internal/conversion/canonical/stream.go`:定义 CanonicalStreamEvent 联合体(message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop, error, ping)及各事件的具体结构(MessageStartEvent 含 message{id,model,usage}、ContentBlockStartEvent 含 index 和 content_block、ContentBlockDeltaEvent 含 index 和 delta、ContentBlockStopEvent 含 index、MessageDeltaEvent 含 delta{stop_reason} 和 usage、MessageStopEvent、ErrorEvent、PingEvent),delta 联合体(text_delta, input_json_delta, thinking_delta),content_block 联合体(text, tool_use, thinking);编写测试
|
||||
- [x] 1.6 创建 `internal/conversion/canonical/extended.go`:定义扩展层 Canonical Models(CanonicalModelList, CanonicalModel, CanonicalModelInfo, CanonicalEmbeddingRequest, CanonicalEmbeddingResponse, CanonicalRerankRequest, CanonicalRerankResponse);编写测试
|
||||
|
||||
## 2. 接口定义层 — Adapter、Stream、Middleware 接口
|
||||
|
||||
- [ ] 2.1 创建 `internal/conversion/adapter.go`:定义 ProtocolAdapter 接口(protocolName, protocolVersion, supportsPassthrough, detectInterfaceType, buildUrl, buildHeaders, supportsInterface, decodeRequest, encodeRequest, decodeResponse, encodeResponse, createStreamDecoder, createStreamEncoder, encodeError, 扩展层编解码方法:decodeModelsResponse/encodeModelsResponse/decodeModelInfoResponse/encodeModelInfoResponse/decodeEmbeddingRequest/encodeEmbeddingRequest/decodeEmbeddingResponse/encodeEmbeddingResponse/decodeRerankRequest/encodeRerankRequest/decodeRerankResponse/encodeRerankResponse),定义 AdapterRegistry 接口(register, get, listProtocols)和 memoryRegistry 实现(sync.RWMutex 保护的 map);编写 Registry 注册/查询/重复注册测试
|
||||
- [ ] 2.2 创建 `internal/conversion/stream.go`:定义 StreamDecoder 接口(processChunk(rawChunk []byte) []CanonicalStreamEvent, flush() []CanonicalStreamEvent)、StreamEncoder 接口(encodeEvent(event CanonicalStreamEvent) [][]byte, flush() [][]byte)、StreamConverter 接口(processChunk(rawChunk []byte) [][]byte, flush() [][]byte)、PassthroughStreamConverter 实现(直接传递原始字节)、CanonicalStreamConverter 实现(组合 StreamDecoder + MiddlewareChain + StreamEncoder,processChunk 内部调用 decoder → middleware → encoder 管道);编写 PassthroughStreamConverter 测试
|
||||
- [ ] 2.3 创建 `internal/conversion/middleware.go`:定义 ConversionMiddleware 接口(intercept(canonical, clientProtocol, providerProtocol, context) (CanonicalRequest, error) 和可选的 interceptStreamEvent(event, clientProtocol, providerProtocol, context) (CanonicalStreamEvent, error))、ConversionContext 结构体(conversionId, interfaceType, timestamp, metadata)、MiddlewareChain 结构体(按注册顺序链式执行,任一返回错误则中断后续);编写链式执行和中断测试
|
||||
- [x] 2.1 创建 `internal/conversion/adapter.go`:定义 ProtocolAdapter 接口(protocolName, protocolVersion, supportsPassthrough, detectInterfaceType, buildUrl, buildHeaders, supportsInterface, decodeRequest, encodeRequest, decodeResponse, encodeResponse, createStreamDecoder, createStreamEncoder, encodeError, 扩展层编解码方法:decodeModelsResponse/encodeModelsResponse/decodeModelInfoResponse/encodeModelInfoResponse/decodeEmbeddingRequest/encodeEmbeddingRequest/decodeEmbeddingResponse/encodeEmbeddingResponse/decodeRerankRequest/encodeRerankRequest/decodeRerankResponse/encodeRerankResponse),定义 AdapterRegistry 接口(register, get, listProtocols)和 memoryRegistry 实现(sync.RWMutex 保护的 map);编写 Registry 注册/查询/重复注册测试
|
||||
- [x] 2.2 创建 `internal/conversion/stream.go`:定义 StreamDecoder 接口(processChunk(rawChunk []byte) []CanonicalStreamEvent, flush() []CanonicalStreamEvent)、StreamEncoder 接口(encodeEvent(event CanonicalStreamEvent) [][]byte, flush() [][]byte)、StreamConverter 接口(processChunk(rawChunk []byte) [][]byte, flush() [][]byte)、PassthroughStreamConverter 实现(直接传递原始字节)、CanonicalStreamConverter 实现(组合 StreamDecoder + MiddlewareChain + StreamEncoder,processChunk 内部调用 decoder → middleware → encoder 管道);编写 PassthroughStreamConverter 测试
|
||||
- [x] 2.3 创建 `internal/conversion/middleware.go`:定义 ConversionMiddleware 接口(intercept(canonical, clientProtocol, providerProtocol, context) (CanonicalRequest, error) 和可选的 interceptStreamEvent(event, clientProtocol, providerProtocol, context) (CanonicalStreamEvent, error))、ConversionContext 结构体(conversionId, interfaceType, timestamp, metadata)、MiddlewareChain 结构体(按注册顺序链式执行,任一返回错误则中断后续);编写链式执行和中断测试
|
||||
|
||||
## 3. 引擎层 — ConversionEngine 门面
|
||||
|
||||
- [ ] 3.1 创建 `internal/conversion/engine.go`:定义 HTTPRequestSpec(URL, Method string, Headers map[string]string, Body []byte)、HTTPResponseSpec(StatusCode int, Headers map[string]string, Body []byte)、ConversionEngine struct(registry, middlewareChain);实现 registerAdapter、use、isPassthrough、convertHttpRequest(接口识别 → 透传判断 → clientAdapter.decode → middleware → providerAdapter.encode → providerAdapter.buildUrl + buildHeaders)、convertHttpResponse(透传判断 → providerAdapter.decodeResponse → clientAdapter.encodeResponse)、createStreamConverter(透传 → PassthroughStreamConverter,否则 → CanonicalStreamConverter)、内部 convertBody 分发(CHAT 走深度转换,扩展层走轻量映射,默认透传);编写集成测试:使用 mock adapter 测试跨协议转换、同协议透传、未知接口透传
|
||||
- [x] 3.1 创建 `internal/conversion/engine.go`:定义 HTTPRequestSpec(URL, Method string, Headers map[string]string, Body []byte)、HTTPResponseSpec(StatusCode int, Headers map[string]string, Body []byte)、ConversionEngine struct(registry, middlewareChain);实现 registerAdapter、use、isPassthrough、convertHttpRequest(接口识别 → 透传判断 → clientAdapter.decode → middleware → providerAdapter.encode → providerAdapter.buildUrl + buildHeaders)、convertHttpResponse(透传判断 → providerAdapter.decodeResponse → clientAdapter.encodeResponse)、createStreamConverter(透传 → PassthroughStreamConverter,否则 → CanonicalStreamConverter)、内部 convertBody 分发(CHAT 走深度转换,扩展层走轻量映射,默认透传);编写集成测试:使用 mock adapter 测试跨协议转换、同协议透传、未知接口透传
|
||||
|
||||
## 4. OpenAI Adapter 实现
|
||||
|
||||
- [ ] 4.1 创建 `internal/conversion/openai/types.go`:从旧 `internal/protocol/openai/types.go` 迁移 OpenAI 线路格式类型,补全缺失字段(developer role, custom tools, reasoning_effort, reasoning_content, max_completion_tokens, parallel_tool_calls, response_format 的 json_schema 类型, stream_options, 废弃的 functions/function_call);编写序列化测试
|
||||
- [ ] 4.2 创建 `internal/conversion/openai/decoder.go`:实现 decodeRequest(对照 conversion_openai.md §4.1:decodeSystemPrompt 提取 system+developer 消息、decodeMessage 含 tool_calls/refusal/reasoning_content 解码、tool 消息 tool_call_id→tool_use_id、decodeTools 含 function+custom 类型、decodeToolChoice 含 required→any/allowed_tools 降级、decodeParameters 含 max_completion_tokens 优先、decodeOutputFormat、decodeThinking 含 reasoning_effort→ThinkingConfig、废弃字段 functions→tools 兼容)、decodeResponse(§5.2:content/refusal/reasoning_content/tool_calls 解码、finish_reason 映射表、usage 映射含 cached_tokens/reasoning_tokens)、扩展层 decode(decodeModelsResponse、decodeEmbeddingRequest/Response、decodeRerankRequest/Response);编写完整测试覆盖每类消息和字段映射
|
||||
- [ ] 4.3 创建 `internal/conversion/openai/encoder.go`:实现 encodeRequest(对照 conversion_openai.md §4.2:provider.model_name 覆盖、system 注入到 messages[0]、encodeMessage 含 tool_calls 编码到 message 顶层、角色交替合并、encodeTools 含 function 包装、encodeToolChoice 含 any→required、encodeParameters 含 max_completion_tokens、encodeOutputFormat、encodeThinking 含 disabled→"none")、encodeResponse(§5.3:text→content、tool_use→tool_calls、thinking→reasoning_content、finish_reason 反向映射、usage 编码含 prompt_tokens_details)、扩展层 encode(encodeModelsResponse、encodeEmbeddingRequest/Response、encodeRerankRequest/Response);编写完整测试
|
||||
- [ ] 4.4 创建 `internal/conversion/openai/adapter.go`:实现 OpenAI ProtocolAdapter(protocolName→"openai"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/chat/completions→CHAT、/v1/models→MODELS 等、buildHeaders 含 Authorization+Content-Type+OpenAI-Organization、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO/EMBEDDINGS/RERANK 返回 true、encodeError 含 ErrorCode→OpenAI 错误类型映射),组合 decoder 和 encoder 方法;编写测试覆盖所有路径模式和边界情况
|
||||
- [ ] 4.5 创建 `internal/conversion/openai/stream_decoder.go`:实现 OpenAIStreamDecoder(对照 conversion_openai.md §6.2-§6.3:processChunk 解析 SSE data 行,维护状态机 messageStarted/openBlocks/toolCallIdMap/toolCallNameMap/toolCallArguments/textBlockStarted/thinkingBlockStarted/utf8Remainder/accumulatedUsage,首个 chunk→MessageStartEvent,delta.content→text block 生命周期,delta.tool_calls→tool_use block 生命周期含索引映射和参数累积,delta.reasoning_content→thinking block(非标准),delta.refusal→text block,finish_reason→关闭所有 open blocks + MessageDeltaEvent + MessageStopEvent,usage chunk→MessageDeltaEvent,[DONE]→flush 关闭);编写测试覆盖每种 delta 类型和边界情况(空 chunk、多 tool_calls、UTF-8 截断)
|
||||
- [ ] 4.6 创建 `internal/conversion/openai/stream_encoder.go`:实现 OpenAIStreamEncoder(对照 conversion_openai.md §6.4:encodeEvent,ContentBlockStart 缓冲策略等待首次 ContentBlockDelta 合并输出,tool_use id/name 在首次 delta 时合并编码,text_delta 直接输出 data: {choices:[{delta:{content}}]},input_json_delta 含 tool_calls 数组编码,thinking_delta 含 reasoning_content 字段,MessageStartEvent→{choices:[{delta:{role:"assistant"}}]},MessageDeltaEvent→{choices:[{delta:{},finish_reason}]},MessageStopEvent→[DONE],PingEvent/ErrorEvent 丢弃,flush 输出缓冲区);编写测试
|
||||
- [x] 4.1 创建 `internal/conversion/openai/types.go`:对照 `docs/conversion_openai.md` 全新定义 OpenAI 线路格式类型(不沿用旧 `internal/protocol/openai/types.go`),包含完整字段(developer role, custom tools, reasoning_effort, reasoning_content, max_completion_tokens, parallel_tool_calls, response_format 的 json_schema 类型, stream_options, 废弃的 functions/function_call);编写序列化测试
|
||||
- [x] 4.2 创建 `internal/conversion/openai/decoder.go`:实现 decodeRequest(对照 conversion_openai.md §4.1:decodeSystemPrompt 提取 system+developer 消息、decodeMessage 含 tool_calls/refusal/reasoning_content 解码、tool 消息 tool_call_id→tool_use_id、decodeTools 含 function+custom 类型、decodeToolChoice 含 required→any/allowed_tools 降级、decodeParameters 含 max_completion_tokens 优先、decodeOutputFormat、decodeThinking 含 reasoning_effort→ThinkingConfig、废弃字段 functions→tools 兼容)、decodeResponse(§5.2:content/refusal/reasoning_content/tool_calls 解码、finish_reason 映射表、usage 映射含 cached_tokens/reasoning_tokens)、扩展层 decode(decodeModelsResponse、decodeEmbeddingRequest/Response、decodeRerankRequest/Response);编写完整测试覆盖每类消息和字段映射
|
||||
- [x] 4.3 创建 `internal/conversion/openai/encoder.go`:实现 encodeRequest(对照 conversion_openai.md §4.2:provider.model_name 覆盖、system 注入到 messages[0]、encodeMessage 含 tool_calls 编码到 message 顶层、角色交替合并、encodeTools 含 function 包装、encodeToolChoice 含 any→required、encodeParameters 含 max_completion_tokens、encodeOutputFormat、encodeThinking 含 disabled→"none")、encodeResponse(§5.3:text→content、tool_use→tool_calls、thinking→reasoning_content、finish_reason 反向映射、usage 编码含 prompt_tokens_details)、扩展层 encode(encodeModelsResponse、encodeEmbeddingRequest/Response、encodeRerankRequest/Response);编写完整测试
|
||||
- [x] 4.4 创建 `internal/conversion/openai/adapter.go`:实现 OpenAI ProtocolAdapter(protocolName→"openai"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/chat/completions→CHAT、/v1/models→MODELS 等、buildHeaders 含 Authorization+Content-Type+OpenAI-Organization、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO/EMBEDDINGS/RERANK 返回 true、encodeError 含 ErrorCode→OpenAI 错误类型映射),组合 decoder 和 encoder 方法;编写测试覆盖所有路径模式和边界情况
|
||||
- [x] 4.5 创建 `internal/conversion/openai/stream_decoder.go`:实现 OpenAIStreamDecoder(对照 conversion_openai.md §6.2-§6.3:processChunk 解析 SSE data 行,维护状态机 messageStarted/openBlocks/toolCallIdMap/toolCallNameMap/toolCallArguments/textBlockStarted/thinkingBlockStarted/utf8Remainder/accumulatedUsage,首个 chunk→MessageStartEvent,delta.content→text block 生命周期,delta.tool_calls→tool_use block 生命周期含索引映射和参数累积,delta.reasoning_content→thinking block(非标准),delta.refusal→text block,finish_reason→关闭所有 open blocks + MessageDeltaEvent + MessageStopEvent,usage chunk→MessageDeltaEvent,[DONE]→flush 关闭);编写测试覆盖每种 delta 类型和边界情况(空 chunk、多 tool_calls、UTF-8 截断)
|
||||
- [x] 4.6 创建 `internal/conversion/openai/stream_encoder.go`:实现 OpenAIStreamEncoder(对照 conversion_openai.md §6.4:encodeEvent,ContentBlockStart 缓冲策略等待首次 ContentBlockDelta 合并输出,tool_use id/name 在首次 delta 时合并编码,text_delta 直接输出 data: {choices:[{delta:{content}}]},input_json_delta 含 tool_calls 数组编码,thinking_delta 含 reasoning_content 字段,MessageStartEvent→{choices:[{delta:{role:"assistant"}}]},MessageDeltaEvent→{choices:[{delta:{},finish_reason}]},MessageStopEvent→[DONE],PingEvent/ErrorEvent 丢弃,flush 输出缓冲区);编写测试
|
||||
|
||||
## 5. Anthropic Adapter 实现(与 Layer 4 并行)
|
||||
|
||||
- [ ] 5.1 创建 `internal/conversion/anthropic/types.go`:从旧 `internal/protocol/anthropic/types.go` 迁移 Anthropic 线路格式类型,补全缺失字段(thinking.type 含 adaptive、output_config.format/effort、disable_parallel_tool_use、metadata.user_id、redacted_thinking、pause_turn/refusal stop_reason、stop_details、container、cache_control);编写序列化测试
|
||||
- [ ] 5.2 创建 `internal/conversion/anthropic/decoder.go`:实现 decodeRequest(对照 conversion_anthropic.md §4.1:decodeSystem 从顶层 system 提取、decodeMessage 含 tool_result 从 user 消息拆分为独立 tool 角色消息、参数直接映射含 top_k、decodeThinking 含 enabled/disabled/adaptive 三种类型、decodeOutputFormat 仅支持 json_schema、公共字段提取含 metadata.user_id/disable_parallel_tool_use 反转/output_config.effort、协议特有字段 redacted_thinking 丢弃/cache_control 忽略)、decodeResponse(§5.2:text/tool_use/thinking 块解码、redacted_thinking 丢弃、stop_reason 映射含 pause_turn/refusal、usage 映射含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 decode(decodeModelsResponse 含 RFC3339→Unix 时间戳转换、decodeModelInfoResponse);编写完整测试覆盖角色拆分、thinking 三种类型、时间戳转换
|
||||
- [ ] 5.3 创建 `internal/conversion/anthropic/encoder.go`:实现 encodeRequest(对照 conversion_anthropic.md §4.2:provider.model_name 覆盖、system 注入为顶层字段、encodeMessages 含 tool→user 合并(优先合并到相邻 user 消息)、首消息 user 保证(自动注入空 user)、角色交替合并、encodeThinkingConfig 含 enabled/disabled/adaptive、encodeOutputFormat 含 json_object→空 schema 降级/text 丢弃、公共字段编码含 metadata.user_id/disable_parallel_tool_use 反转/output_config、参数编码含 max_tokens 必填/top_k 直接映射)、encodeResponse(§5.3:text/tool_use/thinking 块直接编码、stop_reason 映射含 content_filter→end_turn 降级、usage 编码含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 encode(encodeModelsResponse 含 Unix→RFC3339 转换和 has_more/first_id/last_id 字段、encodeModelInfoResponse);编写完整测试覆盖角色合并、首消息注入、降级处理
|
||||
- [ ] 5.4 创建 `internal/conversion/anthropic/adapter.go`:实现 Anthropic ProtocolAdapter(protocolName→"anthropic"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/messages→CHAT、/v1/models→MODELS 等、buildHeaders 含 x-api-key + anthropic-version + anthropic-beta + Content-Type、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO 返回 true 对 EMBEDDINGS/RERANK 返回 false、encodeError 返回 {type:"error",error:{type,message}});编写测试覆盖所有路径模式和边界情况
|
||||
- [ ] 5.5 创建 `internal/conversion/anthropic/stream_decoder.go`:实现 AnthropicStreamDecoder(对照 conversion_anthropic.md §6.2-§6.3:解析命名 SSE 事件 event: message_start/data: {...},1:1 映射到 CanonicalStreamEvent,维护状态 messageStarted/redactedBlocks/utf8Remainder/accumulatedUsage,redacted_thinking 检测后加入 redactedBlocks 并丢弃后续 delta/stop,citations_delta/signature_delta 直接丢弃,server_tool_use 等服务端工具块丢弃,UTF-8 跨 chunk 安全处理);编写测试覆盖所有事件类型和 redacted_thinking 丢弃
|
||||
- [ ] 5.6 创建 `internal/conversion/anthropic/stream_encoder.go`:实现 AnthropicStreamEncoder(对照 conversion_anthropic.md §6.4:直接映射无缓冲,每个 CanonicalStreamEvent 直接编码为对应的 Anthropic 命名 SSE 事件,格式 event: `<type>`\ndata: `<json>`\n\n,delta 编码 text_delta/input_json_delta/thinking_delta 直接映射);编写测试
|
||||
- [x] 5.1 创建 `internal/conversion/anthropic/types.go`:对照 `docs/conversion_anthropic.md` 全新定义 Anthropic 线路格式类型(不沿用旧 `internal/protocol/anthropic/types.go`),包含完整字段(thinking.type 含 adaptive、output_config.format/effort、disable_parallel_tool_use、metadata.user_id、redacted_thinking、pause_turn/refusal stop_reason、stop_details、container、cache_control);编写序列化测试
|
||||
- [x] 5.2 创建 `internal/conversion/anthropic/decoder.go`:实现 decodeRequest(对照 conversion_anthropic.md §4.1:decodeSystem 从顶层 system 提取、decodeMessage 含 tool_result 从 user 消息拆分为独立 tool 角色消息、参数直接映射含 top_k、decodeThinking 含 enabled/disabled/adaptive 三种类型、decodeOutputFormat 仅支持 json_schema、公共字段提取含 metadata.user_id/disable_parallel_tool_use 反转/output_config.effort、协议特有字段 redacted_thinking 丢弃/cache_control 忽略)、decodeResponse(§5.2:text/tool_use/thinking 块解码、redacted_thinking 丢弃、stop_reason 映射含 pause_turn/refusal、usage 映射含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 decode(decodeModelsResponse 含 RFC3339→Unix 时间戳转换、decodeModelInfoResponse);编写完整测试覆盖角色拆分、thinking 三种类型、时间戳转换
|
||||
- [x] 5.3 创建 `internal/conversion/anthropic/encoder.go`:实现 encodeRequest(对照 conversion_anthropic.md §4.2:provider.model_name 覆盖、system 注入为顶层字段、encodeMessages 含 tool→user 合并(优先合并到相邻 user 消息)、首消息 user 保证(自动注入空 user)、角色交替合并、encodeThinkingConfig 含 enabled/disabled/adaptive、encodeOutputFormat 含 json_object→空 schema 降级/text 丢弃、公共字段编码含 metadata.user_id/disable_parallel_tool_use 反转/output_config、参数编码含 max_tokens 必填/top_k 直接映射)、encodeResponse(§5.3:text/tool_use/thinking 块直接编码、stop_reason 映射含 content_filter→end_turn 降级、usage 编码含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 encode(encodeModelsResponse 含 Unix→RFC3339 转换和 has_more/first_id/last_id 字段、encodeModelInfoResponse);编写完整测试覆盖角色合并、首消息注入、降级处理
|
||||
- [x] 5.4 创建 `internal/conversion/anthropic/adapter.go`:实现 Anthropic ProtocolAdapter(protocolName→"anthropic"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/messages→CHAT、/v1/models→MODELS 等、buildHeaders 含 x-api-key + anthropic-version + anthropic-beta + Content-Type、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO 返回 true 对 EMBEDDINGS/RERANK 返回 false、encodeError 返回 {type:"error",error:{type,message}});编写测试覆盖所有路径模式和边界情况
|
||||
- [x] 5.5 创建 `internal/conversion/anthropic/stream_decoder.go`:实现 AnthropicStreamDecoder(对照 conversion_anthropic.md §6.2-§6.3:解析命名 SSE 事件 event: message_start/data: {...},1:1 映射到 CanonicalStreamEvent,维护状态 messageStarted/redactedBlocks/utf8Remainder/accumulatedUsage,redacted_thinking 检测后加入 redactedBlocks 并丢弃后续 delta/stop,citations_delta/signature_delta 直接丢弃,server_tool_use 等服务端工具块丢弃,UTF-8 跨 chunk 安全处理);编写测试覆盖所有事件类型和 redacted_thinking 丢弃
|
||||
- [x] 5.6 创建 `internal/conversion/anthropic/stream_encoder.go`:实现 AnthropicStreamEncoder(对照 conversion_anthropic.md §6.4:直接映射无缓冲,每个 CanonicalStreamEvent 直接编码为对应的 Anthropic 命名 SSE 事件,格式 event: `<type>`\ndata: `<json>`\n\n,delta 编码 text_delta/input_json_delta/thinking_delta 直接映射);编写测试
|
||||
|
||||
## 6. 基础设施改造 — Provider、Handler、Domain
|
||||
|
||||
- [ ] 6.1 修改 `internal/domain/provider.go`:Provider 结构体新增 Protocol string 字段;修改 `internal/config/models.go`:GORM Provider 模型同步新增 Protocol 字段(gorm:"column:protocol;default:'openai'");修改 `internal/repository/` 中 toDomainProvider 和 toConfigProvider 转换函数同步 Protocol 字段;修改 `internal/handler/provider_handler.go`:CreateProvider 和 UpdateProvider 的请求结构体新增 Protocol 字段(可选,默认 "openai"),创建/更新 Provider 时赋值 Protocol 字段,List/Get 响应中包含 Protocol 字段;更新 `internal/service/service_test.go` 中所有创建测试 Provider 的地方补充 Protocol 字段;更新 `internal/handler/handler_test.go` 中 Provider CRUD 测试的请求体补充 Protocol 字段;创建数据库迁移文件 `backend/migrations/YYYYMMDDHHMMSS_add_provider_protocol.sql`:ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'
|
||||
- [ ] 6.2 重写 `internal/provider/client.go`:定义 HTTPRequestSpec 和 HTTPResponseSpec(或引用 conversion 包的定义),简化 ProviderClient 接口为 Send(ctx, HTTPRequestSpec) → (*HTTPResponseSpec, error) 和 SendStream(ctx, HTTPRequestSpec) → (<-chan StreamEvent, error),移除所有 openai.Adapter 硬编码依赖,Send 方法直接使用 http.NewRequest + spec.URL/Headers/Body,SendStream 保留现有 readStream goroutine 逻辑但输入改为 HTTPRequestSpec;重写 `provider/client_test.go`:删除所有基于 openai.ChatCompletionRequest 的旧测试用例,基于 HTTPRequestSpec 重写成功/失败/流式测试用例,使用 httptest.Server 验证请求构建和响应解析
|
||||
- [ ] 6.3 创建 `internal/handler/proxy_handler.go`:实现 ProxyHandler struct(依赖 ConversionEngine、ProviderClient、RoutingService、StatsService),实现 HandleProxy(w, r) 方法:从 URL 提取 clientProtocol(仅支持 `/{protocol}/v1/...` 前缀路由,不支持旧路由)、解析请求体 JSON、调用 RoutingService.Route(modelName) 获取路由结果(含 Provider.Protocol 作为 providerProtocol)、构建 TargetProvider、调用 engine.convertHttpRequest、调用 providerClient.Send/SendStream、调用 engine.convertHttpResponse、设置响应 Content-Type 和状态码、流式处理设置 text/event-stream 并用 StreamConverter 逐块转换写入、错误处理使用 clientAdapter.encodeError、异步调用 StatsService.Record;编写测试使用 httptest + mock engine/client/service
|
||||
- [ ] 6.4 修改 `cmd/server/main.go`:创建 AdapterRegistry 并注册 OpenAI 和 Anthropic Adapter、创建 ConversionEngine(注入 registry)、创建 ProxyHandler(注入 engine + providerClient + routingService + statsService)、配置 Gin 路由:新增 `/{protocol}/v1/{path:*}` → ProxyHandler.HandleProxy,删除旧路由 `/v1/chat/completions` 和 `/v1/messages`,移除旧的 OpenAIHandler 和 AnthropicHandler 的路由注册,移除旧的 Adapter 创建代码
|
||||
- [x] 6.1 修改 `internal/domain/provider.go`:Provider 结构体新增 Protocol string 字段;修改 `internal/config/models.go`:GORM Provider 模型同步新增 Protocol 字段(gorm:"column:protocol;default:'openai'");修改 `internal/repository/` 中 toDomainProvider 和 toConfigProvider 转换函数同步 Protocol 字段;修改 `internal/handler/provider_handler.go`:CreateProvider 和 UpdateProvider 的请求结构体新增 Protocol 字段(可选,默认 "openai"),创建/更新 Provider 时赋值 Protocol 字段,List/Get 响应中包含 Protocol 字段;更新 `internal/service/service_test.go` 中所有创建测试 Provider 的地方补充 Protocol 字段;更新 `internal/handler/handler_test.go` 中 Provider CRUD 测试的请求体补充 Protocol 字段;创建数据库迁移文件 `backend/migrations/YYYYMMDDHHMMSS_add_provider_protocol.sql`:ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'
|
||||
- [x] 6.2 重写 `internal/provider/client.go`:定义 HTTPRequestSpec 和 HTTPResponseSpec(或引用 conversion 包的定义),简化 ProviderClient 接口为 Send(ctx, HTTPRequestSpec) → (*HTTPResponseSpec, error) 和 SendStream(ctx, HTTPRequestSpec) → (<-chan StreamEvent, error),移除所有旧协议硬编码依赖,Send 方法直接使用 http.NewRequest + spec.URL/Headers/Body,SendStream 保留现有 readStream goroutine 逻辑但输入改为 HTTPRequestSpec;重写 `provider/client_test.go`:删除所有基于旧协议类型的测试用例,基于 HTTPRequestSpec 重写成功/失败/流式测试用例,使用 httptest.Server 验证请求构建和响应解析
|
||||
- [x] 6.3 创建 `internal/handler/proxy_handler.go`:实现 ProxyHandler struct(依赖 ConversionEngine、ProviderClient、RoutingService、StatsService),实现 HandleProxy(w, r) 方法:从 URL 提取 clientProtocol(仅支持 `/{protocol}/v1/...` 前缀路由,不支持旧路由)、解析请求体 JSON、调用 RoutingService.Route(modelName) 获取路由结果(含 Provider.Protocol 作为 providerProtocol)、构建 TargetProvider、调用 engine.convertHttpRequest、调用 providerClient.Send/SendStream、调用 engine.convertHttpResponse、设置响应 Content-Type 和状态码、流式处理设置 text/event-stream 并用 StreamConverter 逐块转换写入、错误处理使用 clientAdapter.encodeError、异步调用 StatsService.Record;编写测试使用 httptest + mock engine/client/service
|
||||
- [x] 6.4 修改 `cmd/server/main.go`:创建 AdapterRegistry 并注册 OpenAI 和 Anthropic Adapter、创建 ConversionEngine(注入 registry)、创建 ProxyHandler(注入 engine + providerClient + routingService + statsService)、配置 Gin 路由:新增 `/{protocol}/v1/{path:*}` → ProxyHandler.HandleProxy,删除旧路由 `/v1/chat/completions` 和 `/v1/messages`,移除旧的 OpenAIHandler 和 AnthropicHandler 的路由注册,删除旧 Adapter 创建代码
|
||||
|
||||
## 7. 清理和文档
|
||||
|
||||
- [ ] 7.1 删除旧代码:删除 `internal/protocol/openai/` 目录(types.go, adapter.go, adapter_test.go)、删除 `internal/protocol/anthropic/` 目录(types.go, converter.go, converter_test.go, stream_converter.go, stream_converter_test.go)、删除 `internal/handler/openai_handler.go` 和 `internal/handler/anthropic_handler.go`、删除 `internal/handler/handler_test.go` 中旧 OpenAI/Anthropic handler 测试用例和旧 `mockProviderClient`(基于 openai.ChatCompletionRequest 的签名)、重写 `handler_test.go` 为 ProxyHandler 测试(基于新 ProviderClient 接口和 ConversionEngine mock)、删除 `internal/protocol/` 空目录、确认所有编译通过且无残留 import
|
||||
- [ ] 7.2 更新 `README.md`:更新项目结构说明(新增 internal/conversion/、删除 internal/protocol/)、更新 API 接口说明(代理接口变更:`/{protocol}/v1/...`,移除旧路由 `/v1/chat/completions` 和 `/v1/messages`)、更新配置说明(Provider 新增 protocol 字段)
|
||||
- [ ] 7.3 端到端测试:在 `backend/tests/integration/` 中新增 `conversion_test.go`,使用 httptest mock 上游服务器验证完整请求流:OpenAI→OpenAI 同协议透传、Anthropic→Anthropic 同协议透传、OpenAI→Anthropic 跨协议非流式、Anthropic→OpenAI 跨协议非流式、4 种方向的流式转换(含 tool_calls 和 thinking)、Models 接口跨协议转换、错误响应格式验证(各协议格式)、旧路由 `/v1/chat/completions` 和 `/v1/messages` 返回 404;复用 `tests/helpers.go` 中的测试数据库和 Provider/Model 创建辅助函数
|
||||
- [x] 7.1 删除旧代码:删除 `internal/protocol/openai/` 目录(types.go, adapter.go, adapter_test.go)、删除 `internal/protocol/anthropic/` 目录(types.go, converter.go, converter_test.go, stream_converter.go, stream_converter_test.go)、删除 `internal/handler/openai_handler.go` 和 `internal/handler/anthropic_handler.go`、删除 `internal/handler/handler_test.go` 中旧 OpenAI/Anthropic handler 测试用例和旧 `mockProviderClient`(基于旧协议类型的签名)、重写 `handler_test.go` 为 ProxyHandler 测试(基于新 ProviderClient 接口和 ConversionEngine mock)、删除 `internal/protocol/` 空目录、确认全部编译通过且无残留 import
|
||||
- [x] 7.2 更新 `README.md`:更新项目结构说明(新增 internal/conversion/、删除 internal/protocol/)、更新 API 接口说明(代理接口变更:`/{protocol}/v1/...`,移除旧路由 `/v1/chat/completions` 和 `/v1/messages`)、更新配置说明(Provider 新增 protocol 字段)
|
||||
- [x] 7.3 端到端测试:在 `backend/tests/integration/` 中新增 `conversion_test.go`,使用 httptest mock 上游服务器验证完整请求流:OpenAI→OpenAI 同协议透传、Anthropic→Anthropic 同协议透传、OpenAI→Anthropic 跨协议非流式、Anthropic→OpenAI 跨协议非流式、4 种方向的流式转换(含 tool_calls 和 thinking)、Models 接口跨协议转换、错误响应格式验证(各协议格式)、旧路由 `/v1/chat/completions` 和 `/v1/messages` 返回 404;复用 `tests/helpers.go` 中的测试数据库和 Provider/Model 创建辅助函数
|
||||
|
||||
Reference in New Issue
Block a user