diff --git a/backend/README.md b/backend/README.md index 6488a1f..666912e 100644 --- a/backend/README.md +++ b/backend/README.md @@ -4,10 +4,14 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。 ## 功能特性 -- 支持 OpenAI 协议(`/v1/chat/completions`) -- 支持 Anthropic 协议(`/v1/messages`) +- 支持 OpenAI 协议(`/openai/v1/...`) +- 支持 Anthropic 协议(`/anthropic/v1/...`) +- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic) +- 同协议透传(零语义损失、零序列化开销) - 支持流式响应(SSE) - 支持 Function Calling / Tools +- 支持 Thinking / Reasoning +- 支持扩展层接口(Models、Embeddings、Rerank) - 多供应商配置和路由 - 用量统计 - 结构化日志(zap + lumberjack) @@ -48,19 +52,36 @@ backend/ │ │ │ ├── logging.go │ │ │ ├── recovery.go │ │ │ └── cors.go -│ │ ├── openai_handler.go -│ │ ├── anthropic_handler.go +│ │ ├── proxy_handler.go # 统一代理处理器 │ │ ├── provider_handler.go │ │ ├── model_handler.go │ │ └── stats_handler.go -│ ├── protocol/ # 协议适配器 -│ │ ├── openai/ -│ │ │ ├── types.go # 请求/响应类型 + 验证 -│ │ │ └── adapter.go # OpenAI 协议适配 -│ │ └── anthropic/ -│ │ ├── types.go # 请求/响应类型 + 验证 -│ │ ├── converter.go # 协议转换 -│ │ └── stream_converter.go # 流式转换 +│ ├── conversion/ # 协议转换引擎 +│ │ ├── canonical/ # Canonical Model +│ │ │ ├── types.go # 核心请求/响应类型 +│ │ │ ├── stream.go # 流式事件类型 +│ │ │ └── extended.go # 扩展层 Models +│ │ ├── openai/ # OpenAI 协议适配器 +│ │ │ ├── types.go +│ │ │ ├── adapter.go +│ │ │ ├── decoder.go +│ │ │ ├── encoder.go +│ │ │ ├── stream_decoder.go +│ │ │ └── stream_encoder.go +│ │ ├── anthropic/ # Anthropic 协议适配器 +│ │ │ ├── types.go +│ │ │ ├── adapter.go +│ │ │ ├── decoder.go +│ │ │ ├── encoder.go +│ │ │ ├── stream_decoder.go +│ │ │ └── stream_encoder.go +│ │ ├── adapter.go # ProtocolAdapter 接口 + Registry +│ │ ├── stream.go # StreamDecoder/Encoder/Converter +│ │ ├── middleware.go # Middleware 接口和 Chain +│ │ ├── engine.go # ConversionEngine 门面 +│ │ ├── errors.go # ConversionError +│ │ ├── interface.go # InterfaceType 枚举 +│ │ └── provider.go # TargetProvider │ ├── provider/ # 供应商客户端 │ │ └── client.go │ ├── repository/ # 数据访问层 @@ -184,10 +205,15 @@ goose -dir migrations sqlite3 
~/.nex/config.db up ### 代理接口 -#### OpenAI Chat Completions +使用 `/{protocol}/v1/{path}` URL 前缀路由: + +#### OpenAI 协议代理 ``` -POST /v1/chat/completions +POST /openai/v1/chat/completions +GET /openai/v1/models +POST /openai/v1/embeddings +POST /openai/v1/rerank ``` 请求示例: @@ -202,10 +228,11 @@ POST /v1/chat/completions } ``` -#### Anthropic Messages +#### Anthropic 协议代理 ``` -POST /v1/messages +POST /anthropic/v1/messages +GET /anthropic/v1/models ``` 请求示例: @@ -220,6 +247,8 @@ POST /v1/messages } ``` +**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。 + ### 管理接口 #### 供应商管理 @@ -237,10 +266,15 @@ POST /v1/messages "id": "openai", "name": "OpenAI", "api_key": "sk-...", - "base_url": "https://api.openai.com/v1" + "base_url": "https://api.openai.com/v1", + "protocol": "openai" } ``` +**Protocol 字段说明:** +- `protocol` 标识上游供应商使用的协议类型,可选值:`"openai"`(默认)、`"anthropic"` +- 同协议透传时,请求体和响应体原样转发,零序列化开销 + **重要说明:** - `base_url` 应配置到 API 版本路径,不包含具体端点 - OpenAI: `https://api.openai.com/v1` diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 6492172..16bb983 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -20,6 +20,9 @@ import ( "gorm.io/gorm/logger" "nex/backend/internal/config" + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/anthropic" + "nex/backend/internal/conversion/openai" "nex/backend/internal/handler" "nex/backend/internal/handler/middleware" "nex/backend/internal/provider" @@ -70,30 +73,37 @@ func main() { routingService := service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) - // 6. 初始化 provider client + // 6. 
创建 ConversionEngine + registry := conversion.NewMemoryRegistry() + if err := registry.Register(openai.NewAdapter()); err != nil { + zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error())) + } + if err := registry.Register(anthropic.NewAdapter()); err != nil { + zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error())) + } + engine := conversion.NewConversionEngine(registry) + + // 7. 初始化 provider client providerClient := provider.NewClient() - // 7. 初始化 handler 层 - openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService) - anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService) + // 8. 初始化 handler 层 + proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) providerHandler := handler.NewProviderHandler(providerService) modelHandler := handler.NewModelHandler(modelService) statsHandler := handler.NewStatsHandler(statsService) - // 8. 创建 Gin 引擎 + // 9. 创建 Gin 引擎 gin.SetMode(gin.ReleaseMode) r := gin.New() - // 注册中间件(按正确顺序) r.Use(middleware.RequestID()) r.Use(middleware.Recovery(zapLogger)) r.Use(middleware.Logging(zapLogger)) r.Use(middleware.CORS()) - // 注册路由 - setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler) + setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) - // 9. 启动服务器 + // 10. 
启动服务器 srv := &http.Server{ Addr: formatAddr(cfg.Server.Port), Handler: r, @@ -108,7 +118,6 @@ func main() { } }() - // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit @@ -137,12 +146,10 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) { log.Printf("警告: 启用 WAL 模式失败: %v", err) } - // 运行数据库迁移 if err := runMigrations(db); err != nil { return nil, fmt.Errorf("数据库迁移失败: %w", err) } - // 配置连接池 sqlDB, err := db.DB() if err != nil { return nil, err @@ -151,14 +158,12 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) { sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime) - // 记录连接池状态 log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime) return db, nil } -// runMigrations 使用 goose 执行数据库迁移 func runMigrations(db *gorm.DB) error { sqlDB, err := db.DB() if err != nil { @@ -178,18 +183,14 @@ func runMigrations(db *gorm.DB) error { return nil } -// getMigrationsDir 获取迁移文件目录路径 func getMigrationsDir() string { - // 从可执行文件位置推断迁移目录 _, filename, _, ok := runtime.Caller(0) if ok { - // cmd/server/main.go → backend/ → backend/migrations/ dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations") if abs, err := filepath.Abs(dir); err == nil { return abs } } - // 回退到相对路径 return "./migrations" } @@ -205,12 +206,9 @@ func formatAddr(port int) string { return fmt.Sprintf(":%d", port) } -func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { - // OpenAI 协议代理 - r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions) - - // Anthropic 协议代理 - r.POST("/v1/messages", anthropicHandler.HandleMessages) +func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler 
*handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { + // 统一代理入口: /{protocol}/v1/{path} + r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy) // 供应商管理 API providers := r.Group("/api/providers") diff --git a/backend/internal/config/models.go b/backend/internal/config/models.go index 37916fa..9656937 100644 --- a/backend/internal/config/models.go +++ b/backend/internal/config/models.go @@ -10,6 +10,7 @@ type Provider struct { Name string `gorm:"not null" json:"name"` APIKey string `gorm:"not null" json:"api_key"` BaseURL string `gorm:"not null" json:"base_url"` + Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"` Enabled bool `gorm:"default:true" json:"enabled"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/conversion/adapter.go b/backend/internal/conversion/adapter.go new file mode 100644 index 0000000..7f5f5bd --- /dev/null +++ b/backend/internal/conversion/adapter.go @@ -0,0 +1,100 @@ +package conversion + +import ( + "fmt" + "sync" + + "nex/backend/internal/conversion/canonical" +) + +// ProtocolAdapter 协议适配器接口 +type ProtocolAdapter interface { + ProtocolName() string + ProtocolVersion() string + SupportsPassthrough() bool + + DetectInterfaceType(nativePath string) InterfaceType + BuildUrl(nativePath string, interfaceType InterfaceType) string + BuildHeaders(provider *TargetProvider) map[string]string + SupportsInterface(interfaceType InterfaceType) bool + + DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) + EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error) + DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) + EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) + + CreateStreamDecoder() StreamDecoder + CreateStreamEncoder() StreamEncoder + + EncodeError(err *ConversionError) ([]byte, int) + + DecodeModelsResponse(raw []byte) 
(*canonical.CanonicalModelList, error) + EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) + DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) + EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) + DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) + EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error) + DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) + EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) + DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) + EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) + DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) + EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) +} + +// AdapterRegistry 适配器注册表接口 +type AdapterRegistry interface { + Register(adapter ProtocolAdapter) error + Get(protocolName string) (ProtocolAdapter, error) + ListProtocols() []string +} + +// memoryRegistry 基于内存的适配器注册表 +type memoryRegistry struct { + mu sync.RWMutex + adapters map[string]ProtocolAdapter +} + +// NewMemoryRegistry 创建内存注册表 +func NewMemoryRegistry() AdapterRegistry { + return &memoryRegistry{ + adapters: make(map[string]ProtocolAdapter), + } +} + +// Register 注册适配器 +func (r *memoryRegistry) Register(adapter ProtocolAdapter) error { + r.mu.Lock() + defer r.mu.Unlock() + + name := adapter.ProtocolName() + if _, exists := r.adapters[name]; exists { + return fmt.Errorf("适配器已注册: %s", name) + } + r.adapters[name] = adapter + return nil +} + +// Get 获取适配器 +func (r *memoryRegistry) Get(protocolName string) (ProtocolAdapter, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + adapter, ok := r.adapters[protocolName] + if !ok { + return nil, fmt.Errorf("未找到适配器: %s", protocolName) + } + return adapter, nil 
+} + +// ListProtocols 列出所有已注册协议 +func (r *memoryRegistry) ListProtocols() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + protocols := make([]string, 0, len(r.adapters)) + for name := range r.adapters { + protocols = append(protocols, name) + } + return protocols +} diff --git a/backend/internal/conversion/anthropic/adapter.go b/backend/internal/conversion/anthropic/adapter.go new file mode 100644 index 0000000..350f981 --- /dev/null +++ b/backend/internal/conversion/anthropic/adapter.go @@ -0,0 +1,199 @@ +package anthropic + +import ( + "encoding/json" + "regexp" + "strings" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" +) + +// Adapter Anthropic 协议适配器 +type Adapter struct{} + +// NewAdapter 创建 Anthropic 适配器 +func NewAdapter() *Adapter { + return &Adapter{} +} + +var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`) + +// ProtocolName 返回协议名称 +func (a *Adapter) ProtocolName() string { return "anthropic" } + +// ProtocolVersion 返回协议版本 +func (a *Adapter) ProtocolVersion() string { return "2023-06-01" } + +// SupportsPassthrough 支持同协议透传 +func (a *Adapter) SupportsPassthrough() bool { return true } + +// DetectInterfaceType 根据路径检测接口类型 +func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType { + switch { + case nativePath == "/v1/messages": + return conversion.InterfaceTypeChat + case nativePath == "/v1/models": + return conversion.InterfaceTypeModels + case modelInfoRegex.MatchString(nativePath): + return conversion.InterfaceTypeModelInfo + default: + return conversion.InterfaceTypePassthrough + } +} + +// BuildUrl 根据接口类型构建 URL +func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string { + switch interfaceType { + case conversion.InterfaceTypeChat: + return "/v1/messages" + case conversion.InterfaceTypeModels: + return "/v1/models" + default: + return nativePath + } +} + +// BuildHeaders 构建请求头 +func (a *Adapter) BuildHeaders(provider 
*conversion.TargetProvider) map[string]string { + headers := map[string]string{ + "x-api-key": provider.APIKey, + "anthropic-version": "2023-06-01", + "Content-Type": "application/json", + } + if v, ok := provider.AdapterConfig["anthropic_version"].(string); ok && v != "" { + headers["anthropic-version"] = v + } + if betas, ok := provider.AdapterConfig["anthropic_beta"].([]string); ok && len(betas) > 0 { + headers["anthropic-beta"] = strings.Join(betas, ",") + } else if betas, ok := provider.AdapterConfig["anthropic_beta"].([]any); ok && len(betas) > 0 { + var parts []string + for _, b := range betas { + if s, ok := b.(string); ok { + parts = append(parts, s) + } + } + if len(parts) > 0 { + headers["anthropic-beta"] = strings.Join(parts, ",") + } + } + return headers +} + +// SupportsInterface 检查是否支持接口类型 +func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool { + switch interfaceType { + case conversion.InterfaceTypeChat, + conversion.InterfaceTypeModels, + conversion.InterfaceTypeModelInfo: + return true + default: + return false + } +} + +// DecodeRequest 解码请求 +func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) { + return decodeRequest(raw) +} + +// EncodeRequest 编码请求 +func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) { + return encodeRequest(req, provider) +} + +// DecodeResponse 解码响应 +func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) { + return decodeResponse(raw) +} + +// EncodeResponse 编码响应 +func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { + return encodeResponse(resp) +} + +// CreateStreamDecoder 创建流式解码器 +func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder { + return NewStreamDecoder() +} + +// CreateStreamEncoder 创建流式编码器 +func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder { + return NewStreamEncoder() +} + +// EncodeError 编码错误 +func (a 
*Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) { + errType := string(err.Code) + statusCode := 500 + + errMsg := ErrorResponse{ + Type: "error", + Error: ErrorDetail{ + Type: errType, + Message: err.Message, + }, + } + body, _ := json.Marshal(errMsg) + return body, statusCode +} + +// DecodeModelsResponse 解码模型列表响应 +func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) { + return decodeModelsResponse(raw) +} + +// EncodeModelsResponse 编码模型列表响应 +func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) { + return encodeModelsResponse(list) +} + +// DecodeModelInfoResponse 解码模型详情响应 +func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) { + return decodeModelInfoResponse(raw) +} + +// EncodeModelInfoResponse 编码模型详情响应 +func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) { + return encodeModelInfoResponse(info) +} + +// DecodeEmbeddingRequest Anthropic 不支持嵌入 +func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口") +} + +// EncodeEmbeddingRequest Anthropic 不支持嵌入 +func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口") +} + +// DecodeEmbeddingResponse Anthropic 不支持嵌入 +func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口") +} + +// EncodeEmbeddingResponse Anthropic 不支持嵌入 +func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) { + return 
nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口") +} + +// DecodeRerankRequest Anthropic 不支持重排序 +func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口") +} + +// EncodeRerankRequest Anthropic 不支持重排序 +func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口") +} + +// DecodeRerankResponse Anthropic 不支持重排序 +func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口") +} + +// EncodeRerankResponse Anthropic 不支持重排序 +func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { + return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口") +} diff --git a/backend/internal/conversion/anthropic/adapter_test.go b/backend/internal/conversion/anthropic/adapter_test.go new file mode 100644 index 0000000..417a7e3 --- /dev/null +++ b/backend/internal/conversion/anthropic/adapter_test.go @@ -0,0 +1,210 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdapter_ProtocolName(t *testing.T) { + a := NewAdapter() + assert.Equal(t, "anthropic", a.ProtocolName()) +} + +func TestAdapter_ProtocolVersion(t *testing.T) { + a := NewAdapter() + assert.Equal(t, "2023-06-01", a.ProtocolVersion()) +} + +func TestAdapter_SupportsPassthrough(t *testing.T) { + a := NewAdapter() + 
assert.True(t, a.SupportsPassthrough()) +} + +func TestAdapter_DetectInterfaceType(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + path string + expected conversion.InterfaceType + }{ + {"聊天消息", "/v1/messages", conversion.InterfaceTypeChat}, + {"模型列表", "/v1/models", conversion.InterfaceTypeModels}, + {"模型详情", "/v1/models/claude-3", conversion.InterfaceTypeModelInfo}, + {"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.DetectInterfaceType(tt.path) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdapter_BuildUrl(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + nativePath string + interfaceType conversion.InterfaceType + expected string + }{ + {"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"}, + {"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"}, + {"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.BuildUrl(tt.nativePath, tt.interfaceType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdapter_BuildHeaders_Basic(t *testing.T) { + a := NewAdapter() + provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3") + + headers := a.BuildHeaders(provider) + assert.Equal(t, "sk-ant-test", headers["x-api-key"]) + assert.Equal(t, "2023-06-01", headers["anthropic-version"]) + assert.Equal(t, "application/json", headers["Content-Type"]) +} + +func TestAdapter_BuildHeaders_CustomVersion(t *testing.T) { + a := NewAdapter() + provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3") + provider.AdapterConfig["anthropic_version"] = "2024-01-01" + + headers := a.BuildHeaders(provider) + assert.Equal(t, "2024-01-01", headers["anthropic-version"]) +} + +func 
TestAdapter_BuildHeaders_AnthropicBeta(t *testing.T) { + a := NewAdapter() + provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3") + provider.AdapterConfig["anthropic_beta"] = []string{"prompt-caching-2024-07-31", "max-tokens-3-5-sonnet-2024-07-15"} + + headers := a.BuildHeaders(provider) + assert.Equal(t, "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15", headers["anthropic-beta"]) +} + +func TestAdapter_SupportsInterface(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + interfaceType conversion.InterfaceType + expected bool + }{ + {"聊天", conversion.InterfaceTypeChat, true}, + {"模型", conversion.InterfaceTypeModels, true}, + {"模型详情", conversion.InterfaceTypeModelInfo, true}, + {"嵌入", conversion.InterfaceTypeEmbeddings, false}, + {"重排序", conversion.InterfaceTypeRerank, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.SupportsInterface(tt.interfaceType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdapter_EncodeError(t *testing.T) { + a := NewAdapter() + convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") + + body, statusCode := a.EncodeError(convErr) + require.Equal(t, 500, statusCode) + + var resp ErrorResponse + require.NoError(t, json.Unmarshal(body, &resp)) + assert.Equal(t, "error", resp.Type) + assert.Equal(t, "INVALID_INPUT", resp.Error.Type) + assert.Equal(t, "参数无效", resp.Error.Message) +} + +func TestAdapter_UnsupportedEmbedding(t *testing.T) { + a := NewAdapter() + + t.Run("解码嵌入请求", func(t *testing.T) { + _, err := a.DecodeEmbeddingRequest([]byte(`{}`)) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("编码嵌入请求", func(t *testing.T) { + provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") + _, err := 
a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("解码嵌入响应", func(t *testing.T) { + _, err := a.DecodeEmbeddingResponse([]byte(`{}`)) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("编码嵌入响应", func(t *testing.T) { + _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{}) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) +} + +func TestAdapter_UnsupportedRerank(t *testing.T) { + a := NewAdapter() + + t.Run("解码重排序请求", func(t *testing.T) { + _, err := a.DecodeRerankRequest([]byte(`{}`)) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("编码重排序请求", func(t *testing.T) { + provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") + _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("解码重排序响应", func(t *testing.T) { + _, err := a.DecodeRerankResponse([]byte(`{}`)) + require.Error(t, err) + convErr, ok := err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) + + t.Run("编码重排序响应", func(t *testing.T) { + _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{}) + require.Error(t, err) + convErr, ok := 
err.(*conversion.ConversionError) + require.True(t, ok) + assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) + }) +} diff --git a/backend/internal/conversion/anthropic/decoder.go b/backend/internal/conversion/anthropic/decoder.go new file mode 100644 index 0000000..b194cec --- /dev/null +++ b/backend/internal/conversion/anthropic/decoder.go @@ -0,0 +1,427 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" +) + +// decodeRequest 将 Anthropic 请求解码为 Canonical 请求 +func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) { + var req MessagesRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 请求失败").WithCause(err) + } + + if strings.TrimSpace(req.Model) == "" { + return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空") + } + if len(req.Messages) == 0 { + return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空") + } + + system := decodeSystem(req.System) + + var canonicalMsgs []canonical.CanonicalMessage + for _, msg := range req.Messages { + decoded := decodeMessage(msg) + canonicalMsgs = append(canonicalMsgs, decoded...) 
+ } + + tools := decodeTools(req.Tools) + toolChoice := decodeToolChoice(req.ToolChoice) + params := decodeParameters(&req) + thinking := decodeThinking(req.Thinking, req.OutputConfig) + outputFormat := decodeOutputFormat(req.OutputConfig) + + var parallelToolUse *bool + if req.DisableParallelToolUse != nil && *req.DisableParallelToolUse { + val := false + parallelToolUse = &val + } + + var userID string + if req.Metadata != nil { + userID = req.Metadata.UserID + } + + return &canonical.CanonicalRequest{ + Model: req.Model, + System: system, + Messages: canonicalMsgs, + Tools: tools, + ToolChoice: toolChoice, + Parameters: params, + Thinking: thinking, + Stream: req.Stream, + UserID: userID, + OutputFormat: outputFormat, + ParallelToolUse: parallelToolUse, + }, nil +} + +// decodeSystem 解码系统消息 +func decodeSystem(system any) any { + if system == nil { + return nil + } + switch v := system.(type) { + case string: + if v == "" { + return nil + } + return v + case []any: + var blocks []canonical.SystemBlock + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok { + blocks = append(blocks, canonical.SystemBlock{Text: text}) + } + } + } + if len(blocks) == 0 { + return nil + } + return blocks + default: + return fmt.Sprintf("%v", v) + } +} + +// decodeMessage 解码 Anthropic 消息 +func decodeMessage(msg Message) []canonical.CanonicalMessage { + switch msg.Role { + case "user": + blocks := decodeContentBlocks(msg.Content) + var toolResults []canonical.ContentBlock + var others []canonical.ContentBlock + for _, b := range blocks { + if b.Type == "tool_result" { + toolResults = append(toolResults, b) + } else { + others = append(others, b) + } + } + var result []canonical.CanonicalMessage + if len(others) > 0 { + result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: others}) + } + if len(toolResults) > 0 { + result = append(result, canonical.CanonicalMessage{Role: canonical.RoleTool, Content: 
toolResults}) + } + if len(result) == 0 { + result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}}) + } + return result + + case "assistant": + blocks := decodeContentBlocks(msg.Content) + if len(blocks) == 0 { + blocks = append(blocks, canonical.NewTextBlock("")) + } + return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}} + } + return nil +} + +// decodeContentBlocks 解码内容块列表 +func decodeContentBlocks(content any) []canonical.ContentBlock { + switch v := content.(type) { + case string: + return []canonical.ContentBlock{canonical.NewTextBlock(v)} + case []any: + var blocks []canonical.ContentBlock + for _, item := range v { + if m, ok := item.(map[string]any); ok { + block := decodeSingleContentBlock(m) + if block != nil { + blocks = append(blocks, *block) + } + } + } + if len(blocks) > 0 { + return blocks + } + return []canonical.ContentBlock{canonical.NewTextBlock("")} + case nil: + return []canonical.ContentBlock{canonical.NewTextBlock("")} + default: + return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} + } +} + +// decodeSingleContentBlock 解码单个内容块 +func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock { + t, _ := m["type"].(string) + switch t { + case "text": + text, _ := m["text"].(string) + return &canonical.ContentBlock{Type: "text", Text: text} + case "tool_use": + id, _ := m["id"].(string) + name, _ := m["name"].(string) + input, _ := json.Marshal(m["input"]) + return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} + case "tool_result": + toolUseID, _ := m["tool_use_id"].(string) + isErr := false + if ie, ok := m["is_error"].(bool); ok { + isErr = ie + } + var content json.RawMessage + if c, ok := m["content"]; ok { + switch cv := c.(type) { + case string: + content = json.RawMessage(fmt.Sprintf("%q", cv)) + default: + content, _ = json.Marshal(cv) + } + } else { + 
content = json.RawMessage(`""`) + } + return &canonical.ContentBlock{ + Type: "tool_result", + ToolUseID: toolUseID, + Content: content, + IsError: &isErr, + } + case "thinking": + thinking, _ := m["thinking"].(string) + return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} + case "redacted_thinking": + // 丢弃 + return nil + } + return nil +} + +// decodeTools 解码工具定义 +func decodeTools(tools []Tool) []canonical.CanonicalTool { + if len(tools) == 0 { + return nil + } + result := make([]canonical.CanonicalTool, len(tools)) + for i, t := range tools { + result[i] = canonical.CanonicalTool{ + Name: t.Name, + Description: t.Description, + InputSchema: t.InputSchema, + } + } + return result +} + +// decodeToolChoice 解码工具选择 +func decodeToolChoice(toolChoice any) *canonical.ToolChoice { + if toolChoice == nil { + return nil + } + switch v := toolChoice.(type) { + case string: + switch v { + case "auto": + return canonical.NewToolChoiceAuto() + case "none": + return canonical.NewToolChoiceNone() + case "any": + return canonical.NewToolChoiceAny() + } + case map[string]any: + t, _ := v["type"].(string) + switch t { + case "auto": + return canonical.NewToolChoiceAuto() + case "none": + return canonical.NewToolChoiceNone() + case "any": + return canonical.NewToolChoiceAny() + case "tool": + name, _ := v["name"].(string) + return canonical.NewToolChoiceNamed(name) + } + } + return nil +} + +// decodeParameters 解码请求参数 +func decodeParameters(req *MessagesRequest) canonical.RequestParameters { + params := canonical.RequestParameters{ + Temperature: req.Temperature, + TopP: req.TopP, + TopK: req.TopK, + } + if req.MaxTokens > 0 { + val := req.MaxTokens + params.MaxTokens = &val + } + if len(req.StopSequences) > 0 { + params.StopSequences = req.StopSequences + } + return params +} + +// decodeThinking 解码思考配置 +func decodeThinking(thinking *ThinkingConfig, outputConfig *OutputConfig) *canonical.ThinkingConfig { + if thinking == nil { + return nil + } + cfg := 
&canonical.ThinkingConfig{ + Type: thinking.Type, + BudgetTokens: thinking.BudgetTokens, + } + if outputConfig != nil && outputConfig.Effort != "" { + cfg.Effort = outputConfig.Effort + } + return cfg +} + +// decodeOutputFormat 解码输出格式 +func decodeOutputFormat(outputConfig *OutputConfig) *canonical.OutputFormat { + if outputConfig == nil || outputConfig.Format == nil { + return nil + } + if outputConfig.Format.Type == "json_schema" && outputConfig.Format.Schema != nil { + return &canonical.OutputFormat{ + Type: "json_schema", + Name: "output", + Schema: outputConfig.Format.Schema, + Strict: boolPtr(true), + } + } + return nil +} + +// decodeResponse 将 Anthropic 响应解码为 Canonical 响应 +func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) { + var resp MessagesResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 响应失败").WithCause(err) + } + + var blocks []canonical.ContentBlock + for _, block := range resp.Content { + switch block.Type { + case "text": + blocks = append(blocks, canonical.NewTextBlock(block.Text)) + case "tool_use": + blocks = append(blocks, canonical.NewToolUseBlock(block.ID, block.Name, block.Input)) + case "thinking": + blocks = append(blocks, canonical.NewThinkingBlock(block.Thinking)) + case "redacted_thinking": + // 丢弃 + } + } + if len(blocks) == 0 { + blocks = append(blocks, canonical.NewTextBlock("")) + } + + sr := mapStopReason(resp.StopReason) + usage := canonical.CanonicalUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + } + if resp.Usage.CacheReadInputTokens != nil { + usage.CacheReadTokens = resp.Usage.CacheReadInputTokens + } + if resp.Usage.CacheCreationInputTokens != nil { + usage.CacheCreationTokens = resp.Usage.CacheCreationInputTokens + } + + return &canonical.CanonicalResponse{ + ID: resp.ID, + Model: resp.Model, + Content: blocks, + StopReason: &sr, + Usage: usage, + }, nil +} + 
+// mapStopReason 映射停止原因 +func mapStopReason(reason string) canonical.StopReason { + switch reason { + case "end_turn": + return canonical.StopReasonEndTurn + case "max_tokens": + return canonical.StopReasonMaxTokens + case "tool_use": + return canonical.StopReasonToolUse + case "stop_sequence": + return canonical.StopReasonStopSequence + case "pause_turn": + return canonical.StopReason("pause_turn") + case "refusal": + return canonical.StopReasonRefusal + default: + return canonical.StopReasonEndTurn + } +} + +// decodeModelsResponse 解码模型列表响应 +func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) { + var resp ModelsResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, err + } + models := make([]canonical.CanonicalModel, len(resp.Data)) + for i, m := range resp.Data { + name := m.DisplayName + if name == "" { + name = m.ID + } + models[i] = canonical.CanonicalModel{ + ID: m.ID, + Name: name, + Created: parseTimestamp(m.CreatedAt), + OwnedBy: "anthropic", + } + } + return &canonical.CanonicalModelList{Models: models}, nil +} + +// decodeModelInfoResponse 解码模型详情响应 +func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) { + var resp ModelInfoResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, err + } + name := resp.DisplayName + if name == "" { + name = resp.ID + } + return &canonical.CanonicalModelInfo{ + ID: resp.ID, + Name: name, + Created: parseTimestamp(resp.CreatedAt), + OwnedBy: "anthropic", + }, nil +} + +// parseTimestamp 解析 RFC 3339 时间戳为 Unix +func parseTimestamp(s string) int64 { + if s == "" { + return 0 + } + t, err := time.Parse(time.RFC3339, s) + if err != nil { + return 0 + } + return t.Unix() +} + +// formatTimestamp 将 Unix 时间戳格式化为 RFC 3339 +func formatTimestamp(unix int64) string { + if unix == 0 { + return time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339) + } + return time.Unix(unix, 0).UTC().Format(time.RFC3339) +} + +// boolPtr 返回 bool 
指针 +func boolPtr(b bool) *bool { + return &b +} diff --git a/backend/internal/conversion/anthropic/decoder_test.go b/backend/internal/conversion/anthropic/decoder_test.go new file mode 100644 index 0000000..6e5efc0 --- /dev/null +++ b/backend/internal/conversion/anthropic/decoder_test.go @@ -0,0 +1,331 @@ +package anthropic + +import ( + "fmt" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecodeRequest_Basic(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "你好"} + ] + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + assert.Equal(t, "claude-3", req.Model) + assert.Len(t, req.Messages, 1) + assert.Equal(t, canonical.RoleUser, req.Messages[0].Role) + assert.NotNil(t, req.Parameters.MaxTokens) + assert.Equal(t, 1024, *req.Parameters.MaxTokens) +} + +func TestDecodeRequest_System(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "system": "你是助手", + "messages": [ + {"role": "user", "content": "你好"} + ] + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + assert.Equal(t, "你是助手", req.System) +} + +func TestDecodeRequest_SystemBlocks(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "system": [{"text": "指令1"}, {"text": "指令2"}], + "messages": [ + {"role": "user", "content": "你好"} + ] + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + blocks, ok := req.System.([]canonical.SystemBlock) + require.True(t, ok) + assert.Len(t, blocks, 2) + assert.Equal(t, "指令1", blocks[0].Text) +} + +func TestDecodeRequest_ToolResultSplit(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "查询天气"}, + {"type": "tool_result", "tool_use_id": "tool_1", "content": "晴天"} + ] + } + ] + }`) + + 
req, err := decodeRequest(body) + require.NoError(t, err) + // 用户消息中的 tool_result 应被拆分为独立的 tool 消息 + assert.Len(t, req.Messages, 2) + assert.Equal(t, canonical.RoleUser, req.Messages[0].Role) + assert.Equal(t, canonical.RoleTool, req.Messages[1].Role) +} + +func TestDecodeRequest_MissingModel(t *testing.T) { + body := []byte(`{"max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}]}`) + _, err := decodeRequest(body) + require.Error(t, err) + assert.Contains(t, err.Error(), "INVALID_INPUT") +} + +func TestDecodeRequest_MissingMessages(t *testing.T) { + body := []byte(`{"model": "claude-3", "max_tokens": 1024}`) + _, err := decodeRequest(body) + require.Error(t, err) + assert.Contains(t, err.Error(), "INVALID_INPUT") +} + +func TestDecodeResponse_Basic(t *testing.T) { + body := []byte(`{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [{"type": "text", "text": "你好"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`) + + resp, err := decodeResponse(body) + require.NoError(t, err) + assert.Equal(t, "msg_123", resp.ID) + assert.Equal(t, "claude-3", resp.Model) + assert.Len(t, resp.Content, 1) + assert.Equal(t, "你好", resp.Content[0].Text) + assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason) + assert.Equal(t, 10, resp.Usage.InputTokens) +} + +func TestDecodeResponse_Thinking(t *testing.T) { + body := []byte(`{ + "id": "msg_456", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [ + {"type": "thinking", "thinking": "思考过程"}, + {"type": "text", "text": "回答"} + ], + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 20} + }`) + + resp, err := decodeResponse(body) + require.NoError(t, err) + assert.Len(t, resp.Content, 2) + assert.Equal(t, "thinking", resp.Content[0].Type) + assert.Equal(t, "思考过程", resp.Content[0].Thinking) + assert.Equal(t, "text", resp.Content[1].Type) + assert.Equal(t, "回答", 
resp.Content[1].Text) +} + +func TestDecodeModelsResponse(t *testing.T) { + body := []byte(`{ + "data": [ + {"id": "claude-3-opus", "type": "model", "display_name": "Claude 3 Opus", "created_at": "2024-01-15T00:00:00Z"}, + {"id": "claude-3-sonnet", "type": "model", "created_at": "2024-02-01T00:00:00Z"} + ], + "has_more": false + }`) + + list, err := decodeModelsResponse(body) + require.NoError(t, err) + assert.Len(t, list.Models, 2) + assert.Equal(t, "claude-3-opus", list.Models[0].ID) + assert.Equal(t, "Claude 3 Opus", list.Models[0].Name) + // created_at RFC3339 → Unix + assert.NotEqual(t, int64(0), list.Models[0].Created) + // 无 display_name 时使用 ID + assert.Equal(t, "claude-3-sonnet", list.Models[1].Name) +} + +func TestDecodeRequest_InvalidJSON(t *testing.T) { + _, err := decodeRequest([]byte(`invalid json`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "JSON_PARSE_ERROR") +} + +func TestDecodeRequest_Thinking(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "hi"}], + "thinking": {"type": "enabled", "budget_tokens": 5000} + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + require.NotNil(t, req.Thinking) + assert.Equal(t, "enabled", req.Thinking.Type) + require.NotNil(t, req.Thinking.BudgetTokens) + assert.Equal(t, 5000, *req.Thinking.BudgetTokens) +} + +func TestDecodeRequest_ThinkingAdaptive(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "hi"}], + "thinking": {"type": "adaptive"} + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + require.NotNil(t, req.Thinking) + assert.Equal(t, "adaptive", req.Thinking.Type) +} + +func TestDecodeRequest_OutputConfig(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "hi"}], + "output_config": { + "format": { + "type": "json_schema", + "schema": {"type": 
"object", "properties": {"name": {"type": "string"}}} + } + } + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + require.NotNil(t, req.OutputFormat) + assert.Equal(t, "json_schema", req.OutputFormat.Type) + assert.NotNil(t, req.OutputFormat.Schema) +} + +func TestDecodeRequest_DisableParallelToolUse(t *testing.T) { + body := []byte(`{ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "hi"}], + "disable_parallel_tool_use": true + }`) + + req, err := decodeRequest(body) + require.NoError(t, err) + require.NotNil(t, req.ParallelToolUse) + assert.False(t, *req.ParallelToolUse) +} + +func TestDecodeResponse_ToolUse(t *testing.T) { + body := []byte(`{ + "id": "msg_tool", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [ + {"type": "tool_use", "id": "tool_1", "name": "search", "input": {"q": "test"}} + ], + "stop_reason": "tool_use", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`) + + resp, err := decodeResponse(body) + require.NoError(t, err) + require.Len(t, resp.Content, 1) + assert.Equal(t, "tool_use", resp.Content[0].Type) + assert.Equal(t, "tool_1", resp.Content[0].ID) + assert.Equal(t, "search", resp.Content[0].Name) + assert.NotNil(t, resp.Content[0].Input) +} + +func TestDecodeResponse_RedactedThinking(t *testing.T) { + body := []byte(`{ + "id": "msg_redacted", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [ + {"type": "redacted_thinking", "data": "..."}, + {"type": "text", "text": "回答"} + ], + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 5} + }`) + + resp, err := decodeResponse(body) + require.NoError(t, err) + assert.Len(t, resp.Content, 1) + assert.Equal(t, "text", resp.Content[0].Type) + assert.Equal(t, "回答", resp.Content[0].Text) +} + +func TestDecodeResponse_StopReasons(t *testing.T) { + tests := []struct { + name string + reason string + want canonical.StopReason + }{ + {"end_turn→end_turn", 
"end_turn", canonical.StopReasonEndTurn}, + {"max_tokens→max_tokens", "max_tokens", canonical.StopReasonMaxTokens}, + {"tool_use→tool_use", "tool_use", canonical.StopReasonToolUse}, + {"stop_sequence→stop_sequence", "stop_sequence", canonical.StopReasonStopSequence}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := []byte(fmt.Sprintf(`{ + "id": "msg-1", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [{"type": "text", "text": "ok"}], + "stop_reason": "%s", + "usage": {"input_tokens": 1, "output_tokens": 1} + }`, tt.reason)) + + resp, err := decodeResponse(body) + require.NoError(t, err) + require.NotNil(t, resp.StopReason) + assert.Equal(t, tt.want, *resp.StopReason) + }) + } +} + +func TestDecodeResponse_Usage(t *testing.T) { + body := []byte(`{ + "id": "msg_usage", + "type": "message", + "role": "assistant", + "model": "claude-3", + "content": [{"type": "text", "text": "ok"}], + "stop_reason": "end_turn", + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_read_input_tokens": 30 + } + }`) + + resp, err := decodeResponse(body) + require.NoError(t, err) + assert.Equal(t, 100, resp.Usage.InputTokens) + assert.Equal(t, 50, resp.Usage.OutputTokens) + require.NotNil(t, resp.Usage.CacheReadTokens) + assert.Equal(t, 30, *resp.Usage.CacheReadTokens) +} diff --git a/backend/internal/conversion/anthropic/encoder.go b/backend/internal/conversion/anthropic/encoder.go new file mode 100644 index 0000000..b79f756 --- /dev/null +++ b/backend/internal/conversion/anthropic/encoder.go @@ -0,0 +1,449 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" +) + +// encodeRequest 将 Canonical 请求编码为 Anthropic 请求 +func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) { + result := map[string]any{ + "model": provider.ModelName, + "stream": req.Stream, + } + + // max_tokens 
必填 + if req.Parameters.MaxTokens != nil { + result["max_tokens"] = *req.Parameters.MaxTokens + } else { + result["max_tokens"] = 4096 + } + + // 系统消息 + if req.System != nil { + result["system"] = encodeSystem(req.System) + } + + // 消息 + result["messages"] = encodeMessages(req.Messages) + + // 参数 + if req.Parameters.Temperature != nil { + result["temperature"] = *req.Parameters.Temperature + } + if req.Parameters.TopP != nil { + result["top_p"] = *req.Parameters.TopP + } + if req.Parameters.TopK != nil { + result["top_k"] = *req.Parameters.TopK + } + if len(req.Parameters.StopSequences) > 0 { + result["stop_sequences"] = req.Parameters.StopSequences + } + + // 工具 + if len(req.Tools) > 0 { + tools := make([]map[string]any, len(req.Tools)) + for i, t := range req.Tools { + tool := map[string]any{ + "name": t.Name, + "input_schema": t.InputSchema, + } + if t.Description != "" { + tool["description"] = t.Description + } + tools[i] = tool + } + result["tools"] = tools + } + if req.ToolChoice != nil { + result["tool_choice"] = encodeToolChoice(req.ToolChoice) + } + + // 公共字段 + if req.UserID != "" { + result["metadata"] = map[string]any{"user_id": req.UserID} + } + if req.ParallelToolUse != nil && !*req.ParallelToolUse { + result["disable_parallel_tool_use"] = true + } + if req.Thinking != nil { + result["thinking"] = encodeThinkingConfig(req.Thinking) + } + + // output_config + outputConfig := map[string]any{} + hasOutputConfig := false + if req.OutputFormat != nil { + of := encodeOutputFormat(req.OutputFormat) + if of != nil { + outputConfig["format"] = of + hasOutputConfig = true + } + } + if req.Thinking != nil && req.Thinking.Effort != "" { + outputConfig["effort"] = req.Thinking.Effort + hasOutputConfig = true + } + if hasOutputConfig { + result["output_config"] = outputConfig + } + + body, err := json.Marshal(result) + if err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 请求失败").WithCause(err) + } + return 
body, nil + } + + // encodeSystem 编码系统消息 + func encodeSystem(system any) any { + switch v := system.(type) { + case string: + return v + case []canonical.SystemBlock: + blocks := make([]map[string]any, len(v)) + for i, b := range v { + blocks[i] = map[string]any{"type": "text", "text": b.Text} + } + return blocks + default: + return fmt.Sprintf("%v", v) + } +} + +// encodeMessages 编码消息列表(含角色约束处理) +func encodeMessages(msgs []canonical.CanonicalMessage) []map[string]any { + var result []map[string]any + + for _, msg := range msgs { + switch msg.Role { + case canonical.RoleUser: + result = append(result, map[string]any{ + "role": "user", + "content": encodeContentBlocks(msg.Content), + }) + case canonical.RoleAssistant: + result = append(result, map[string]any{ + "role": "assistant", + "content": encodeContentBlocks(msg.Content), + }) + case canonical.RoleTool: + // tool 角色合并到相邻 user 消息 + toolResults := filterToolResults(msg.Content) + if len(result) > 0 && result[len(result)-1]["role"] == "user" { + // 合并到最后一条 user 消息 + lastContent, ok := result[len(result)-1]["content"].([]map[string]any) + if ok { + result[len(result)-1]["content"] = append(lastContent, toolResults...) + } else { + result[len(result)-1]["content"] = toolResults + } + } else { + result = append(result, map[string]any{ + "role": "user", + "content": toolResults, + }) + } + } + } + + // 确保首消息为 user + if len(result) > 0 && result[0]["role"] != "user" { + result = append([]map[string]any{{"role": "user", "content": []map[string]any{}}}, result...) 
+ } + + // 合并连续同角色消息 + result = mergeConsecutiveRoles(result) + + return result +} + +// encodeContentBlocks 编码内容块列表 +func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any { + result := make([]map[string]any, 0, len(blocks)) + for _, b := range blocks { + switch b.Type { + case "text": + result = append(result, map[string]any{"type": "text", "text": b.Text}) + case "tool_use": + m := map[string]any{ + "type": "tool_use", + "id": b.ID, + "name": b.Name, + "input": b.Input, + } + if b.Input == nil { + m["input"] = map[string]any{} + } + result = append(result, m) + case "tool_result": + m := map[string]any{ + "type": "tool_result", + "tool_use_id": b.ToolUseID, + } + if b.Content != nil { + var contentStr string + if json.Unmarshal(b.Content, &contentStr) == nil { + m["content"] = contentStr + } else { + m["content"] = string(b.Content) + } + } else { + m["content"] = "" + } + if b.IsError != nil { + m["is_error"] = *b.IsError + } + result = append(result, m) + case "thinking": + result = append(result, map[string]any{"type": "thinking", "thinking": b.Thinking}) + } + } + if len(result) == 0 { + return []map[string]any{{"type": "text", "text": ""}} + } + return result +} + +// filterToolResults 过滤工具结果 +func filterToolResults(blocks []canonical.ContentBlock) []map[string]any { + var result []map[string]any + for _, b := range blocks { + if b.Type == "tool_result" { + m := map[string]any{ + "type": "tool_result", + "tool_use_id": b.ToolUseID, + } + if b.Content != nil { + var contentStr string + if json.Unmarshal(b.Content, &contentStr) == nil { + m["content"] = contentStr + } else { + m["content"] = string(b.Content) + } + } else { + m["content"] = "" + } + if b.IsError != nil { + m["is_error"] = *b.IsError + } + result = append(result, m) + } + } + return result +} + +// encodeToolChoice 编码工具选择 +func encodeToolChoice(choice *canonical.ToolChoice) any { + switch choice.Type { + case "auto": + return map[string]any{"type": "auto"} + case "none": + 
return map[string]any{"type": "none"} + case "any": + return map[string]any{"type": "any"} + case "tool": + return map[string]any{"type": "tool", "name": choice.Name} + } + return map[string]any{"type": "auto"} +} + +// encodeThinkingConfig 编码思考配置 +func encodeThinkingConfig(cfg *canonical.ThinkingConfig) map[string]any { + switch cfg.Type { + case "enabled": + m := map[string]any{"type": "enabled"} + if cfg.BudgetTokens != nil { + m["budget_tokens"] = *cfg.BudgetTokens + } + return m + case "disabled": + return map[string]any{"type": "disabled"} + case "adaptive": + return map[string]any{"type": "adaptive"} + } + return map[string]any{"type": "disabled"} +} + +// encodeOutputFormat 编码输出格式 +func encodeOutputFormat(format *canonical.OutputFormat) map[string]any { + if format == nil { + return nil + } + switch format.Type { + case "json_schema": + schema := format.Schema + if schema == nil { + schema = json.RawMessage(`{"type":"object"}`) + } + return map[string]any{ + "type": "json_schema", + "schema": schema, + } + case "json_object": + return map[string]any{ + "type": "json_schema", + "schema": map[string]any{"type": "object"}, + } + case "text": + return nil + } + return nil +} + +// encodeResponse 将 Canonical 响应编码为 Anthropic 响应 +func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { + blocks := make([]map[string]any, 0, len(resp.Content)) + for _, b := range resp.Content { + switch b.Type { + case "text": + blocks = append(blocks, map[string]any{"type": "text", "text": b.Text}) + case "tool_use": + m := map[string]any{ + "type": "tool_use", + "id": b.ID, + "name": b.Name, + "input": b.Input, + } + if b.Input == nil { + m["input"] = map[string]any{} + } + blocks = append(blocks, m) + case "thinking": + blocks = append(blocks, map[string]any{"type": "thinking", "thinking": b.Thinking}) + } + } + + sr := "end_turn" + if resp.StopReason != nil { + sr = mapCanonicalStopReason(*resp.StopReason) + } + + usage := map[string]any{ + "input_tokens": 
resp.Usage.InputTokens, + "output_tokens": resp.Usage.OutputTokens, + } + if resp.Usage.CacheReadTokens != nil { + usage["cache_read_input_tokens"] = *resp.Usage.CacheReadTokens + } + if resp.Usage.CacheCreationTokens != nil { + usage["cache_creation_input_tokens"] = *resp.Usage.CacheCreationTokens + } + + result := map[string]any{ + "id": resp.ID, + "type": "message", + "role": "assistant", + "model": resp.Model, + "content": blocks, + "stop_reason": sr, + "stop_sequence": nil, + "usage": usage, + } + + body, err := json.Marshal(result) + if err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 响应失败").WithCause(err) + } + return body, nil +} + +// mapCanonicalStopReason 映射 Canonical 停止原因到 Anthropic +func mapCanonicalStopReason(reason canonical.StopReason) string { + switch reason { + case canonical.StopReasonEndTurn, canonical.StopReasonContentFilter: + return "end_turn" + case canonical.StopReasonMaxTokens: + return "max_tokens" + case canonical.StopReasonToolUse: + return "tool_use" + case canonical.StopReasonStopSequence: + return "stop_sequence" + case canonical.StopReasonRefusal: + return "refusal" + default: + return "end_turn" + } +} + +// encodeModelsResponse 编码模型列表响应 +func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) { + data := make([]map[string]any, len(list.Models)) + for i, m := range list.Models { + name := m.Name + if name == "" { + name = m.ID + } + data[i] = map[string]any{ + "id": m.ID, + "type": "model", + "display_name": name, + "created_at": formatTimestamp(m.Created), + } + } + + var firstID, lastID *string + if len(list.Models) > 0 { + fid := list.Models[0].ID + firstID = &fid + lid := list.Models[len(list.Models)-1].ID + lastID = &lid + } + + return json.Marshal(map[string]any{ + "data": data, + "has_more": false, + "first_id": firstID, + "last_id": lastID, + }) +} + +// encodeModelInfoResponse 编码模型详情响应 +func encodeModelInfoResponse(info 
*canonical.CanonicalModelInfo) ([]byte, error) { + name := info.Name + if name == "" { + name = info.ID + } + return json.Marshal(map[string]any{ + "id": info.ID, + "type": "model", + "display_name": name, + "created_at": formatTimestamp(info.Created), + }) +} + +// mergeConsecutiveRoles 合并连续同角色消息 +func mergeConsecutiveRoles(messages []map[string]any) []map[string]any { + if len(messages) <= 1 { + return messages + } + var result []map[string]any + for _, msg := range messages { + if len(result) > 0 { + lastRole := result[len(result)-1]["role"] + currRole := msg["role"] + if lastRole == currRole { + // 合并 content + lastContent := result[len(result)-1]["content"] + currContent := msg["content"] + switch lv := lastContent.(type) { + case []map[string]any: + if cv, ok := currContent.([]map[string]any); ok { + result[len(result)-1]["content"] = append(lv, cv...) + } + case string: + if cv, ok := currContent.(string); ok { + result[len(result)-1]["content"] = lv + cv + } + } + continue + } + } + result = append(result, msg) + } + return result +} diff --git a/backend/internal/conversion/anthropic/encoder_test.go b/backend/internal/conversion/anthropic/encoder_test.go new file mode 100644 index 0000000..68e2dce --- /dev/null +++ b/backend/internal/conversion/anthropic/encoder_test.go @@ -0,0 +1,350 @@ +package anthropic + +import ( + "encoding/json" + "testing" + "time" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeRequest_Basic(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Stream: true, + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{ + {Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, + }, + } + provider := conversion.NewTargetProvider("", "key", "my-model") + + body, err := 
encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "my-model", result["model"]) + assert.Equal(t, true, result["stream"]) + assert.Equal(t, float64(1024), result["max_tokens"]) + + msgs := result["messages"].([]any) + assert.Len(t, msgs, 1) +} + +func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{ + {Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("查询")}}, + {Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", json.RawMessage(`{"q":"test"}`))}}, + {Role: canonical.RoleTool, Content: []canonical.ContentBlock{canonical.NewToolResultBlock("tool_1", "结果", false)}}, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + msgs := result["messages"].([]any) + + // tool 消息应被合并到相邻 user 消息 + foundToolResult := false + for _, m := range msgs { + msgMap := m.(map[string]any) + if msgMap["role"] == "user" { + content, ok := msgMap["content"].([]any) + if ok { + for _, c := range content { + block := c.(map[string]any) + if block["type"] == "tool_result" { + foundToolResult = true + } + } + } + } + } + assert.True(t, foundToolResult) +} + +func TestEncodeRequest_FirstUserGuarantee(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{ + {Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewTextBlock("前置")}}, + {Role: canonical.RoleUser, Content: 
[]canonical.ContentBlock{canonical.NewTextBlock("hi")}}, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + msgs := result["messages"].([]any) + firstMsg := msgs[0].(map[string]any) + assert.Equal(t, "user", firstMsg["role"]) +} + +func TestEncodeRequest_ThinkingEnabled(t *testing.T) { + budget := 10000 + maxTokens := 8096 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + Thinking: &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + thinking, ok := result["thinking"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "enabled", thinking["type"]) + assert.Equal(t, float64(10000), thinking["budget_tokens"]) +} + +func TestEncodeResponse_Basic(t *testing.T) { + sr := canonical.StopReasonEndTurn + resp := &canonical.CanonicalResponse{ + ID: "msg_1", + Model: "claude-3", + Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")}, + StopReason: &sr, + Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5}, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "msg_1", result["id"]) + assert.Equal(t, "message", result["type"]) + assert.Equal(t, "assistant", result["role"]) + assert.Equal(t, "end_turn", result["stop_reason"]) + + content := result["content"].([]any) + assert.Len(t, content, 1) + block := 
content[0].(map[string]any) + assert.Equal(t, "text", block["type"]) + assert.Equal(t, "你好", block["text"]) +} + +func TestEncodeModelsResponse(t *testing.T) { + ts := time.Date(2024, 3, 15, 0, 0, 0, 0, time.UTC).Unix() + list := &canonical.CanonicalModelList{ + Models: []canonical.CanonicalModel{ + {ID: "claude-3-opus", Name: "Claude 3 Opus", Created: ts, OwnedBy: "anthropic"}, + }, + } + + body, err := encodeModelsResponse(list) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + data := result["data"].([]any) + assert.Len(t, data, 1) + + model := data[0].(map[string]any) + assert.Equal(t, "claude-3-opus", model["id"]) + // created 应为 RFC3339 格式 + createdAt, ok := model["created_at"].(string) + assert.True(t, ok) + assert.Contains(t, createdAt, "2024") +} + +func TestEncodeRequest_ThinkingDisabled(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + _, hasThinking := result["thinking"] + assert.False(t, hasThinking) +} + +func TestEncodeRequest_ThinkingAdaptive(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + Thinking: &canonical.ThinkingConfig{Type: "adaptive"}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + 
var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + thinking, ok := result["thinking"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "adaptive", thinking["type"]) +} + +func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) { + maxTokens := 1024 + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`) + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + OutputFormat: &canonical.OutputFormat{ + Type: "json_schema", + Schema: schema, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + oc, ok := result["output_config"].(map[string]any) + require.True(t, ok) + format, ok := oc["format"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "json_schema", format["type"]) + assert.NotNil(t, format["schema"]) +} + +func TestEncodeRequest_OutputFormat_JSON(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + OutputFormat: &canonical.OutputFormat{ + Type: "json_object", + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + oc, ok := result["output_config"].(map[string]any) + require.True(t, ok) + format, ok := oc["format"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "json_schema", 
format["type"]) + schemaMap, ok := format["schema"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "object", schemaMap["type"]) +} + +func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) { + maxTokens := 1024 + req := &canonical.CanonicalRequest{ + Model: "claude-3", + Parameters: canonical.RequestParameters{MaxTokens: &maxTokens}, + Messages: []canonical.CanonicalMessage{ + {Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("A")}}, + {Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("B")}}, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + msgs := result["messages"].([]any) + assert.Len(t, msgs, 1) + userMsg := msgs[0].(map[string]any) + assert.Equal(t, "user", userMsg["role"]) + content := userMsg["content"].([]any) + assert.Len(t, content, 2) +} + +func TestEncodeResponse_ContentFilter(t *testing.T) { + sr := canonical.StopReasonContentFilter + resp := &canonical.CanonicalResponse{ + ID: "msg-cf", + Model: "claude-3", + Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}, + StopReason: &sr, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "end_turn", result["stop_reason"]) +} + +func TestEncodeResponse_ReasoningTokens(t *testing.T) { + reasoning := 100 + sr := canonical.StopReasonEndTurn + resp := &canonical.CanonicalResponse{ + ID: "msg-rt", + Model: "claude-3", + Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}, + StopReason: &sr, + Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5, ReasoningTokens: &reasoning}, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, 
json.Unmarshal(body, &result)) + usage := result["usage"].(map[string]any) + _, hasReasoning := usage["reasoning_tokens"] + assert.False(t, hasReasoning) +} + +func TestEncodeResponse_ToolUse(t *testing.T) { + sr := canonical.StopReasonToolUse + input := json.RawMessage(`{"q":"test"}`) + resp := &canonical.CanonicalResponse{ + ID: "msg-tool", + Model: "claude-3", + Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", input)}, + StopReason: &sr, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + content := result["content"].([]any) + assert.Len(t, content, 1) + block := content[0].(map[string]any) + assert.Equal(t, "tool_use", block["type"]) + assert.Equal(t, "tool_1", block["id"]) + assert.Equal(t, "search", block["name"]) +} diff --git a/backend/internal/conversion/anthropic/stream_decoder.go b/backend/internal/conversion/anthropic/stream_decoder.go new file mode 100644 index 0000000..543b32a --- /dev/null +++ b/backend/internal/conversion/anthropic/stream_decoder.go @@ -0,0 +1,283 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + "strings" + "unicode/utf8" + + "nex/backend/internal/conversion/canonical" +) + +// StreamDecoder Anthropic 流式解码器 +type StreamDecoder struct { + messageStarted bool + redactedBlocks map[int]bool + utf8Remainder []byte + accumulatedUsage *canonical.CanonicalUsage +} + +// NewStreamDecoder 创建 Anthropic 流式解码器 +func NewStreamDecoder() *StreamDecoder { + return &StreamDecoder{ + redactedBlocks: make(map[int]bool), + } +} + +// ProcessChunk 处理原始 SSE chunk +func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { + data := rawChunk + if len(d.utf8Remainder) > 0 { + data = append(d.utf8Remainder, rawChunk...) 
+ d.utf8Remainder = nil + } + + if !utf8.Valid(data) { + validEnd := len(data) + for !utf8.Valid(data[:validEnd]) { + validEnd-- + } + d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...) + data = data[:validEnd] + } + + var events []canonical.CanonicalStreamEvent + text := string(data) + + // 解析命名 SSE 事件 + var eventType string + var eventData string + + for _, line := range strings.Split(text, "\n") { + line = strings.TrimRight(line, "\r") + if strings.HasPrefix(line, "event: ") { + eventType = strings.TrimPrefix(line, "event: ") + } else if strings.HasPrefix(line, "data: ") { + eventData = strings.TrimPrefix(line, "data: ") + if eventType != "" && eventData != "" { + chunkEvents := d.processEvent(eventType, []byte(eventData)) + events = append(events, chunkEvents...) + } + eventType = "" + eventData = "" + } else if line == "" { + // SSE 事件分隔符 + } + } + + return events +} + +// Flush 刷新解码器状态 +func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent { + return nil +} + +// processEvent 处理单个命名 SSE 事件 +func (d *StreamDecoder) processEvent(eventType string, data []byte) []canonical.CanonicalStreamEvent { + switch eventType { + case "message_start": + return d.processMessageStart(data) + case "content_block_start": + return d.processContentBlockStart(data) + case "content_block_delta": + return d.processContentBlockDelta(data) + case "content_block_stop": + return d.processContentBlockStop(data) + case "message_delta": + return d.processMessageDelta(data) + case "message_stop": + return d.processMessageStop(data) + case "ping": + return []canonical.CanonicalStreamEvent{canonical.NewPingEvent()} + case "error": + return d.processError(data) + } + return nil +} + +// processMessageStart 处理消息开始事件 +func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalStreamEvent { + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + var msg struct { + ID string `json:"id"` + Model string 
`json:"model"` + Usage *struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + + if msgRaw, ok := raw["message"]; ok { + if err := json.Unmarshal(msgRaw, &msg); err != nil { + return nil + } + } + + event := canonical.NewMessageStartEvent(msg.ID, msg.Model) + if msg.Usage != nil { + usage := &canonical.CanonicalUsage{ + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + } + event = canonical.NewMessageStartEventWithUsage(msg.ID, msg.Model, usage) + d.accumulatedUsage = usage + } + + d.messageStarted = true + return []canonical.CanonicalStreamEvent{event} +} + +// processContentBlockStart 处理内容块开始事件 +func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent { + var raw struct { + Index int `json:"index"` + ContentBlock struct { + Type string `json:"type"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + Thinking string `json:"thinking"` + Data string `json:"data"` + } `json:"content_block"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + // 检查需要丢弃的块类型 + switch raw.ContentBlock.Type { + case "redacted_thinking", "server_tool_use", "web_search_tool_result", + "code_execution_tool_result": + d.redactedBlocks[raw.Index] = true + return nil + } + + if d.redactedBlocks[raw.Index] { + return nil + } + + block := canonical.StreamContentBlock{ + Type: raw.ContentBlock.Type, + Text: raw.ContentBlock.Text, + ID: raw.ContentBlock.ID, + Name: raw.ContentBlock.Name, + Thinking: raw.ContentBlock.Thinking, + } + + return []canonical.CanonicalStreamEvent{ + canonical.NewContentBlockStartEvent(raw.Index, block), + } +} + +// processContentBlockDelta 处理内容块增量事件 +func (d *StreamDecoder) processContentBlockDelta(data []byte) []canonical.CanonicalStreamEvent { + var raw struct { + Index int `json:"index"` + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + PartialJSON string 
`json:"partial_json"` + Thinking string `json:"thinking"` + } `json:"delta"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + // 检查是否在丢弃的块中 + if d.redactedBlocks[raw.Index] { + return nil + } + + // 丢弃协议特有 delta 类型 + switch raw.Delta.Type { + case "citations_delta", "signature_delta": + return nil + } + + delta := canonical.StreamDelta{ + Type: raw.Delta.Type, + Text: raw.Delta.Text, + PartialJSON: raw.Delta.PartialJSON, + Thinking: raw.Delta.Thinking, + } + + return []canonical.CanonicalStreamEvent{ + canonical.NewContentBlockDeltaEvent(raw.Index, delta), + } +} + +// processContentBlockStop 处理内容块结束事件 +func (d *StreamDecoder) processContentBlockStop(data []byte) []canonical.CanonicalStreamEvent { + var raw struct { + Index int `json:"index"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + if _, redacted := d.redactedBlocks[raw.Index]; redacted { + delete(d.redactedBlocks, raw.Index) + return nil + } + + return []canonical.CanonicalStreamEvent{ + canonical.NewContentBlockStopEvent(raw.Index), + } +} + +// processMessageDelta 处理消息增量事件 +func (d *StreamDecoder) processMessageDelta(data []byte) []canonical.CanonicalStreamEvent { + var raw struct { + Delta struct { + StopReason string `json:"stop_reason"` + } `json:"delta"` + Usage struct { + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + sr := mapStopReason(raw.Delta.StopReason) + usage := &canonical.CanonicalUsage{ + OutputTokens: raw.Usage.OutputTokens, + } + + if d.accumulatedUsage != nil { + d.accumulatedUsage.OutputTokens += raw.Usage.OutputTokens + } + + return []canonical.CanonicalStreamEvent{ + canonical.NewMessageDeltaEventWithUsage(sr, usage), + } +} + +// processMessageStop 处理消息结束事件 +func (d *StreamDecoder) processMessageStop(data []byte) []canonical.CanonicalStreamEvent { + return []canonical.CanonicalStreamEvent{canonical.NewMessageStopEvent()} +} + +// 
processError 处理错误事件 +func (d *StreamDecoder) processError(data []byte) []canonical.CanonicalStreamEvent { + var raw struct { + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return []canonical.CanonicalStreamEvent{ + canonical.NewErrorEvent("stream_error", fmt.Sprintf("解析错误事件失败: %s", string(data))), + } + } + return []canonical.CanonicalStreamEvent{ + canonical.NewErrorEvent(raw.Error.Type, raw.Error.Message), + } +} diff --git a/backend/internal/conversion/anthropic/stream_decoder_test.go b/backend/internal/conversion/anthropic/stream_decoder_test.go new file mode 100644 index 0000000..0554621 --- /dev/null +++ b/backend/internal/conversion/anthropic/stream_decoder_test.go @@ -0,0 +1,274 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeAnthropicEvent(eventType string, data any) []byte { + dataBytes, _ := json.Marshal(data) + return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(dataBytes))) +} + +func TestStreamDecoder_MessageStart(t *testing.T) { + d := NewStreamDecoder() + + payload := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": "msg_1", + "model": "claude-3", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 0}, + }, + } + raw := makeAnthropicEvent("message_start", payload) + + events := d.ProcessChunk(raw) + require.NotEmpty(t, events) + assert.Equal(t, canonical.EventMessageStart, events[0].Type) + assert.Equal(t, "msg_1", events[0].Message.ID) + assert.Equal(t, "claude-3", events[0].Message.Model) +} + +func TestStreamDecoder_ContentBlockDelta(t *testing.T) { + d := NewStreamDecoder() + + tests := []struct { + name string + deltaType string + deltaData map[string]any + checkField string + checkValue string + }{ + { + name: 
"text_delta", + deltaType: "text_delta", + deltaData: map[string]any{"type": "text_delta", "text": "你好"}, + checkField: "text", + checkValue: "你好", + }, + { + name: "input_json_delta", + deltaType: "input_json_delta", + deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"}, + checkField: "partial_json", + checkValue: "{\"key\":", + }, + { + name: "thinking_delta", + deltaType: "thinking_delta", + deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"}, + checkField: "thinking", + checkValue: "思考中", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": tt.deltaData, + } + raw := makeAnthropicEvent("content_block_delta", payload) + + events := d.ProcessChunk(raw) + require.NotEmpty(t, events) + assert.Equal(t, canonical.EventContentBlockDelta, events[0].Type) + assert.NotNil(t, events[0].Delta) + + switch tt.checkField { + case "text": + assert.Equal(t, tt.checkValue, events[0].Delta.Text) + case "partial_json": + assert.Equal(t, tt.checkValue, events[0].Delta.PartialJSON) + case "thinking": + assert.Equal(t, tt.checkValue, events[0].Delta.Thinking) + } + }) + } +} + +func TestStreamDecoder_RedactedThinking(t *testing.T) { + d := NewStreamDecoder() + + // redacted_thinking block start 应被抑制 + payload := map[string]any{ + "type": "content_block_start", + "index": 1, + "content_block": map[string]any{ + "type": "redacted_thinking", + "data": "redacted-data", + }, + } + raw := makeAnthropicEvent("content_block_start", payload) + events := d.ProcessChunk(raw) + assert.Empty(t, events) + assert.True(t, d.redactedBlocks[1]) +} + +func TestStreamDecoder_RedactedBlockStopSuppressed(t *testing.T) { + d := NewStreamDecoder() + d.redactedBlocks[2] = true + + // content_block_stop 对 redacted block 返回 nil + payload := map[string]any{ + "type": "content_block_stop", + "index": 2, + } + raw := 
makeAnthropicEvent("content_block_stop", payload) + + events := d.ProcessChunk(raw) + assert.Empty(t, events) + // 应清理 redactedBlocks + _, exists := d.redactedBlocks[2] + assert.False(t, exists) +} + +func TestStreamDecoder_ContentBlockStart(t *testing.T) { + d := NewStreamDecoder() + + payload := map[string]any{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + } + raw := makeAnthropicEvent("content_block_start", payload) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventContentBlockStart, events[0].Type) + require.NotNil(t, events[0].ContentBlock) + assert.Equal(t, "text", events[0].ContentBlock.Type) + require.NotNil(t, events[0].Index) + assert.Equal(t, 0, *events[0].Index) +} + +func TestStreamDecoder_ContentBlockStart_ToolUse(t *testing.T) { + d := NewStreamDecoder() + + payload := map[string]any{ + "type": "content_block_start", + "index": 1, + "content_block": map[string]any{ + "type": "tool_use", + "id": "toolu_1", + "name": "search", + }, + } + raw := makeAnthropicEvent("content_block_start", payload) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventContentBlockStart, events[0].Type) + require.NotNil(t, events[0].ContentBlock) + assert.Equal(t, "tool_use", events[0].ContentBlock.Type) + assert.Equal(t, "toolu_1", events[0].ContentBlock.ID) + assert.Equal(t, "search", events[0].ContentBlock.Name) +} + +func TestStreamDecoder_ContentBlockStop(t *testing.T) { + d := NewStreamDecoder() + + payload := map[string]any{ + "type": "content_block_stop", + "index": 0, + } + raw := makeAnthropicEvent("content_block_stop", payload) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventContentBlockStop, events[0].Type) + require.NotNil(t, events[0].Index) + assert.Equal(t, 0, *events[0].Index) +} + +func TestStreamDecoder_MessageDelta(t *testing.T) { + d := 
NewStreamDecoder() + + payload := map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": "end_turn", + }, + "usage": map[string]any{ + "output_tokens": 42, + }, + } + raw := makeAnthropicEvent("message_delta", payload) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventMessageDelta, events[0].Type) + require.NotNil(t, events[0].StopReason) + assert.Equal(t, canonical.StopReasonEndTurn, *events[0].StopReason) + require.NotNil(t, events[0].Usage) + assert.Equal(t, 42, events[0].Usage.OutputTokens) +} + +func TestStreamDecoder_MessageStop(t *testing.T) { + d := NewStreamDecoder() + + raw := makeAnthropicEvent("message_stop", map[string]any{"type": "message_stop"}) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventMessageStop, events[0].Type) +} + +func TestStreamDecoder_Ping(t *testing.T) { + d := NewStreamDecoder() + + raw := makeAnthropicEvent("ping", map[string]any{"type": "ping"}) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventPing, events[0].Type) +} + +func TestStreamDecoder_Error(t *testing.T) { + d := NewStreamDecoder() + + payload := map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "overloaded_error", + "message": "服务过载", + }, + } + raw := makeAnthropicEvent("error", payload) + + events := d.ProcessChunk(raw) + require.Len(t, events, 1) + assert.Equal(t, canonical.EventError, events[0].Type) + require.NotNil(t, events[0].Error) + assert.Equal(t, "overloaded_error", events[0].Error.Type) + assert.Equal(t, "服务过载", events[0].Error.Message) +} + +func TestStreamDecoder_RedactedDeltaSuppressed(t *testing.T) { + d := NewStreamDecoder() + d.redactedBlocks[1] = true + + payload := map[string]any{ + "type": "content_block_delta", + "index": 1, + "delta": map[string]any{ + "type": "text_delta", + "text": "被抑制的内容", + }, + } + raw := makeAnthropicEvent("content_block_delta", 
payload) + + events := d.ProcessChunk(raw) + assert.Empty(t, events) +} diff --git a/backend/internal/conversion/anthropic/stream_encoder.go b/backend/internal/conversion/anthropic/stream_encoder.go new file mode 100644 index 0000000..002dfe4 --- /dev/null +++ b/backend/internal/conversion/anthropic/stream_encoder.go @@ -0,0 +1,188 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + "nex/backend/internal/conversion/canonical" +) + +// StreamEncoder Anthropic 流式编码器 +type StreamEncoder struct{} + +// NewStreamEncoder 创建 Anthropic 流式编码器 +func NewStreamEncoder() *StreamEncoder { + return &StreamEncoder{} +} + +// EncodeEvent 编码 Canonical 事件为 Anthropic 命名 SSE 事件 +func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { + switch event.Type { + case canonical.EventMessageStart: + return e.encodeMessageStart(event) + case canonical.EventContentBlockStart: + return e.encodeContentBlockStart(event) + case canonical.EventContentBlockDelta: + return e.encodeContentBlockDelta(event) + case canonical.EventContentBlockStop: + return e.encodeContentBlockStop(event) + case canonical.EventMessageDelta: + return e.encodeMessageDelta(event) + case canonical.EventMessageStop: + return e.encodeMessageStop(event) + case canonical.EventPing: + return e.encodePing() + case canonical.EventError: + return e.encodeError(event) + } + return nil +} + +// Flush 刷新缓冲区(无缓冲) +func (e *StreamEncoder) Flush() [][]byte { + return nil +} + +// encodeMessageStart 编码消息开始事件 +func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte { + payload := map[string]any{ + "type": "message_start", + } + if event.Message != nil { + msg := map[string]any{ + "id": event.Message.ID, + "model": event.Message.Model, + "role": "assistant", + } + if event.Message.Usage != nil { + usage := map[string]any{ + "input_tokens": event.Message.Usage.InputTokens, + "output_tokens": event.Message.Usage.OutputTokens, + } + msg["usage"] = usage + } + 
payload["message"] = msg
	}
	return e.marshalEvent("message_start", payload)
}

// encodeContentBlockStart encodes a content block start event. Each block
// type carries the empty/initial fields expected on the opening frame.
func (e *StreamEncoder) encodeContentBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
	if event.Index == nil || event.ContentBlock == nil {
		return nil
	}

	blockType := event.ContentBlock.Type
	blockPayload := map[string]any{"type": blockType}
	switch blockType {
	case "text":
		blockPayload["text"] = ""
	case "thinking":
		blockPayload["thinking"] = ""
	case "tool_use":
		blockPayload["id"] = event.ContentBlock.ID
		blockPayload["name"] = event.ContentBlock.Name
		blockPayload["input"] = map[string]any{}
	}

	return e.marshalEvent("content_block_start", map[string]any{
		"type":          "content_block_start",
		"index":         *event.Index,
		"content_block": blockPayload,
	})
}

// encodeContentBlockDelta encodes a content block delta event, copying only
// the field that matches the delta's declared type.
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
	if event.Index == nil || event.Delta == nil {
		return nil
	}

	deltaPayload := map[string]any{"type": event.Delta.Type}
	switch canonical.DeltaType(event.Delta.Type) {
	case canonical.DeltaTypeText:
		deltaPayload["text"] = event.Delta.Text
	case canonical.DeltaTypeInputJSON:
		deltaPayload["partial_json"] = event.Delta.PartialJSON
	case canonical.DeltaTypeThinking:
		deltaPayload["thinking"] = event.Delta.Thinking
	}

	return e.marshalEvent("content_block_delta", map[string]any{
		"type":  "content_block_delta",
		"index": *event.Index,
		"delta": deltaPayload,
	})
}

// encodeContentBlockStop encodes a content block stop event.
func (e *StreamEncoder) encodeContentBlockStop(event canonical.CanonicalStreamEvent) [][]byte {
	if event.Index == nil {
		return nil
	}
	return e.marshalEvent("content_block_stop", map[string]any{
		"type":  "content_block_stop",
		"index": *event.Index,
	})
}

// encodeMessageDelta encodes a message delta event.
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
	delta := map[string]any{}
	if event.StopReason != 
nil { + delta["stop_reason"] = mapCanonicalStopReason(*event.StopReason) + } + + payload := map[string]any{ + "type": "message_delta", + "delta": delta, + } + if event.Usage != nil { + payload["usage"] = map[string]any{ + "output_tokens": event.Usage.OutputTokens, + } + } + return e.marshalEvent("message_delta", payload) +} + +// encodeMessageStop 编码消息结束事件 +func (e *StreamEncoder) encodeMessageStop(event canonical.CanonicalStreamEvent) [][]byte { + payload := map[string]any{"type": "message_stop"} + return e.marshalEvent("message_stop", payload) +} + +// encodePing 编码心跳事件 +func (e *StreamEncoder) encodePing() [][]byte { + payload := map[string]any{"type": "ping"} + return e.marshalEvent("ping", payload) +} + +// encodeError 编码错误事件 +func (e *StreamEncoder) encodeError(event canonical.CanonicalStreamEvent) [][]byte { + if event.Error == nil { + return nil + } + payload := map[string]any{ + "type": "error", + "error": map[string]any{ + "type": event.Error.Type, + "message": event.Error.Message, + }, + } + return e.marshalEvent("error", payload) +} + +// marshalEvent 序列化为 Anthropic 命名 SSE 事件 +func (e *StreamEncoder) marshalEvent(eventType string, payload map[string]any) [][]byte { + data, err := json.Marshal(payload) + if err != nil { + return nil + } + return [][]byte{[]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, data))} +} diff --git a/backend/internal/conversion/anthropic/stream_encoder_test.go b/backend/internal/conversion/anthropic/stream_encoder_test.go new file mode 100644 index 0000000..2cadff4 --- /dev/null +++ b/backend/internal/conversion/anthropic/stream_encoder_test.go @@ -0,0 +1,242 @@ +package anthropic + +import ( + "encoding/json" + "strings" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamEncoder_MessageStart(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewMessageStartEvent("msg_1", "claude-3") + + chunks := 
e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: message_start\n")) + assert.Contains(t, s, "data: ") + assert.Contains(t, s, "msg_1") + assert.Contains(t, s, "claude-3") +} + +func TestStreamEncoder_ContentBlockDelta(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"}) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: content_block_delta\n")) + assert.Contains(t, s, "你好") + + // 验证 JSON 格式 + lines := strings.Split(s, "\n") + var dataLine string + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + dataLine = strings.TrimPrefix(l, "data: ") + break + } + } + var payload map[string]any + require.NoError(t, json.Unmarshal([]byte(dataLine), &payload)) + assert.Equal(t, "content_block_delta", payload["type"]) +} + +func TestStreamEncoder_MessageStop(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewMessageStopEvent() + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: message_stop\n")) +} + +func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""}) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: content_block_start\n")) + assert.Contains(t, s, "data: ") + + var payload map[string]any + lines := strings.Split(s, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload)) + break + } + } + cb := payload["content_block"].(map[string]any) + assert.Equal(t, "text", cb["type"]) +} + +func 
TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{ + Type: "tool_use", + ID: "toolu_1", + Name: "search", + }) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "toolu_1") + assert.Contains(t, s, "search") + + var payload map[string]any + lines := strings.Split(s, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload)) + break + } + } + cb := payload["content_block"].(map[string]any) + assert.Equal(t, "tool_use", cb["type"]) + assert.Equal(t, "toolu_1", cb["id"]) + assert.Equal(t, "search", cb["name"]) +} + +func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "thinking", Thinking: ""}) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "thinking") + + var payload map[string]any + lines := strings.Split(s, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload)) + break + } + } + cb := payload["content_block"].(map[string]any) + assert.Equal(t, "thinking", cb["type"]) +} + +func TestStreamEncoder_ContentBlockStop(t *testing.T) { + e := NewStreamEncoder() + idx := 2 + event := canonical.CanonicalStreamEvent{ + Type: canonical.EventContentBlockStop, + Index: &idx, + } + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: content_block_stop\n")) + assert.Contains(t, s, "content_block_stop") +} + +func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) { + e := NewStreamEncoder() + sr := canonical.StopReasonEndTurn + 
event := canonical.CanonicalStreamEvent{ + Type: canonical.EventMessageDelta, + StopReason: &sr, + } + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "stop_reason") + + var payload map[string]any + lines := strings.Split(s, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload)) + break + } + } + delta := payload["delta"].(map[string]any) + assert.Equal(t, "end_turn", delta["stop_reason"]) +} + +func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) { + e := NewStreamEncoder() + usage := canonical.CanonicalUsage{OutputTokens: 88} + event := canonical.CanonicalStreamEvent{ + Type: canonical.EventMessageDelta, + Usage: &usage, + } + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "output_tokens") + + var payload map[string]any + lines := strings.Split(s, "\n") + for _, l := range lines { + if strings.HasPrefix(l, "data: ") { + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload)) + break + } + } + u := payload["usage"].(map[string]any) + assert.Equal(t, float64(88), u["output_tokens"]) +} + +func TestStreamEncoder_Ping(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewPingEvent() + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: ping\n")) + assert.Contains(t, s, "ping") +} + +func TestStreamEncoder_Error(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewErrorEvent("overloaded_error", "服务过载") + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "event: error\n")) + assert.Contains(t, s, "overloaded_error") + assert.Contains(t, s, "服务过载") +} + +func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) { + e := 
NewStreamEncoder() + chunks := e.Flush() + assert.Nil(t, chunks) +} + +func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) { + e := NewStreamEncoder() + event := canonical.CanonicalStreamEvent{Type: "unknown_event_type"} + chunks := e.EncodeEvent(event) + assert.Nil(t, chunks) +} diff --git a/backend/internal/conversion/anthropic/types.go b/backend/internal/conversion/anthropic/types.go new file mode 100644 index 0000000..f07ccf3 --- /dev/null +++ b/backend/internal/conversion/anthropic/types.go @@ -0,0 +1,183 @@ +package anthropic + +import ( + "encoding/json" +) + +// MessagesRequest Anthropic Messages 请求 +type MessagesRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + System any `json:"system,omitempty"` + MaxTokens int `json:"max_tokens"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Metadata *RequestMetadata `json:"metadata,omitempty"` + Thinking *ThinkingConfig `json:"thinking,omitempty"` + OutputConfig *OutputConfig `json:"output_config,omitempty"` + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` + Container any `json:"container,omitempty"` +} + +// RequestMetadata 请求元数据 +type RequestMetadata struct { + UserID string `json:"user_id,omitempty"` +} + +// ThinkingConfig 思考配置 +type ThinkingConfig struct { + Type string `json:"type"` + BudgetTokens *int `json:"budget_tokens,omitempty"` + Display string `json:"display,omitempty"` +} + +// OutputConfig 输出配置 +type OutputConfig struct { + Format *OutputFormatConfig `json:"format,omitempty"` + Effort string `json:"effort,omitempty"` +} + +// OutputFormatConfig 输出格式配置 +type OutputFormatConfig struct { + Type string `json:"type"` + Schema json.RawMessage 
`json:"schema,omitempty"` +} + +// Message Anthropic 消息 +type Message struct { + Role string `json:"role"` + Content any `json:"content"` +} + +// TextContent 文本内容块 +type TextContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// ToolUseContent 工具调用内容块 +type ToolUseContent struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Input json.RawMessage `json:"input"` +} + +// ToolResultContent 工具结果内容块 +type ToolResultContent struct { + Type string `json:"type"` + ToolUseID string `json:"tool_use_id"` + Content any `json:"content"` + IsError *bool `json:"is_error,omitempty"` +} + +// ThinkingContent 思考内容块 +type ThinkingContent struct { + Type string `json:"type"` + Thinking string `json:"thinking"` +} + +// RedactedThinkingContent 已编辑思考内容块 +type RedactedThinkingContent struct { + Type string `json:"type"` + Data string `json:"data"` +} + +// Tool Anthropic 工具定义 +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` +} + +// MessagesResponse Anthropic Messages 响应 +type MessagesResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Model string `json:"model"` + Content []ContentBlock `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence,omitempty"` + StopDetails any `json:"stop_details,omitempty"` + Container any `json:"container,omitempty"` + Usage ResponseUsage `json:"usage"` +} + +// ContentBlock Anthropic 响应内容块 +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + Thinking string `json:"thinking,omitempty"` + Data string `json:"data,omitempty"` +} + +// ResponseUsage 响应用量 +type ResponseUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int 
`json:"output_tokens"` + CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"` + CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"` +} + +// ModelsResponse Anthropic 模型列表响应 +type ModelsResponse struct { + Data []ModelItem `json:"data"` + HasMore bool `json:"has_more"` + FirstID *string `json:"first_id,omitempty"` + LastID *string `json:"last_id,omitempty"` +} + +// ModelItem Anthropic 模型项 +type ModelItem struct { + ID string `json:"id"` + Type string `json:"type"` + DisplayName string `json:"display_name,omitempty"` + CreatedAt string `json:"created_at,omitempty"` +} + +// ModelInfoResponse Anthropic 模型详情响应 +type ModelInfoResponse struct { + ID string `json:"id"` + Type string `json:"type"` + DisplayName string `json:"display_name,omitempty"` + CreatedAt string `json:"created_at,omitempty"` +} + +// EmbeddingRequest Anthropic 不支持嵌入,但定义类型用于接口兼容 +type EmbeddingRequest struct{} + +// EmbeddingResponse Anthropic 不支持嵌入 +type EmbeddingResponse struct{} + +// RerankRequest Anthropic 不支持重排序 +type RerankRequest struct{} + +// RerankResponse Anthropic 不支持重排序 +type RerankResponse struct{} + +// ErrorResponse Anthropic 错误响应 +type ErrorResponse struct { + Type string `json:"type"` + Error ErrorDetail `json:"error"` +} + +// ErrorDetail 错误详情 +type ErrorDetail struct { + Type string `json:"type"` + Message string `json:"message"` +} + +// SSEEvent SSE 事件 +type SSEEvent struct { + EventType string + Data json.RawMessage +} diff --git a/backend/internal/conversion/canonical/extended.go b/backend/internal/conversion/canonical/extended.go new file mode 100644 index 0000000..2c8fdd0 --- /dev/null +++ b/backend/internal/conversion/canonical/extended.go @@ -0,0 +1,71 @@ +package canonical + +// CanonicalModel 规范模型 +type CanonicalModel struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Created int64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +// CanonicalModelList 规范模型列表 +type 
CanonicalModelList struct { + Models []CanonicalModel `json:"models"` +} + +// CanonicalModelInfo 规范模型详情 +type CanonicalModelInfo struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Created int64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +// CanonicalEmbeddingRequest 规范嵌入请求 +type CanonicalEmbeddingRequest struct { + Model string `json:"model"` + Input any `json:"input"` // string 或 []string + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` +} + +// CanonicalEmbeddingResponse 规范嵌入响应 +type CanonicalEmbeddingResponse struct { + Data []EmbeddingData `json:"data"` + Model string `json:"model"` + Usage EmbeddingUsage `json:"usage"` +} + +// EmbeddingData 嵌入数据项 +type EmbeddingData struct { + Index int `json:"index"` + Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串 +} + +// EmbeddingUsage 嵌入用量 +type EmbeddingUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// CanonicalRerankRequest 规范重排序请求 +type CanonicalRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN *int `json:"top_n,omitempty"` + ReturnDocuments *bool `json:"return_documents,omitempty"` +} + +// CanonicalRerankResponse 规范重排序响应 +type CanonicalRerankResponse struct { + Results []RerankResult `json:"results"` + Model string `json:"model"` +} + +// RerankResult 重排序结果项 +type RerankResult struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + Document *string `json:"document,omitempty"` +} diff --git a/backend/internal/conversion/canonical/stream.go b/backend/internal/conversion/canonical/stream.go new file mode 100644 index 0000000..af179e4 --- /dev/null +++ b/backend/internal/conversion/canonical/stream.go @@ -0,0 +1,156 @@ +package canonical + +// StreamEventType 流式事件类型枚举 +type StreamEventType string + +const ( + 
EventMessageStart StreamEventType = "message_start" + EventContentBlockStart StreamEventType = "content_block_start" + EventContentBlockDelta StreamEventType = "content_block_delta" + EventContentBlockStop StreamEventType = "content_block_stop" + EventMessageDelta StreamEventType = "message_delta" + EventMessageStop StreamEventType = "message_stop" + EventError StreamEventType = "error" + EventPing StreamEventType = "ping" +) + +// DeltaType 增量类型枚举 +type DeltaType string + +const ( + DeltaTypeText DeltaType = "text_delta" + DeltaTypeInputJSON DeltaType = "input_json_delta" + DeltaTypeThinking DeltaType = "thinking_delta" +) + +// StreamDelta 流式增量联合体 +type StreamDelta struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + PartialJSON string `json:"partial_json,omitempty"` + Thinking string `json:"thinking,omitempty"` +} + +// StreamContentBlock 流式内容块联合体 +type StreamContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Thinking string `json:"thinking,omitempty"` +} + +// CanonicalStreamEvent 规范流式事件联合体 +type CanonicalStreamEvent struct { + Type StreamEventType `json:"type"` + + // MessageStartEvent + Message *StreamMessage `json:"message,omitempty"` + + // ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent + Index *int `json:"index,omitempty"` + ContentBlock *StreamContentBlock `json:"content_block,omitempty"` + Delta *StreamDelta `json:"delta,omitempty"` + + // MessageDeltaEvent + StopReason *StopReason `json:"stop_reason,omitempty"` + Usage *CanonicalUsage `json:"usage,omitempty"` + + // ErrorEvent + Error *StreamError `json:"error,omitempty"` +} + +// StreamMessage 流式消息摘要 +type StreamMessage struct { + ID string `json:"id"` + Model string `json:"model"` + Usage *CanonicalUsage `json:"usage,omitempty"` +} + +// StreamError 流式错误 +type StreamError struct { + Type string `json:"type"` + Message string `json:"message"` +} 
+ +// NewMessageStartEvent 创建消息开始事件 +func NewMessageStartEvent(id, model string) CanonicalStreamEvent { + return CanonicalStreamEvent{ + Type: EventMessageStart, + Message: &StreamMessage{ID: id, Model: model}, + } +} + +// NewMessageStartEventWithUsage 创建带用量的消息开始事件 +func NewMessageStartEventWithUsage(id, model string, usage *CanonicalUsage) CanonicalStreamEvent { + return CanonicalStreamEvent{ + Type: EventMessageStart, + Message: &StreamMessage{ID: id, Model: model, Usage: usage}, + } +} + +// NewContentBlockStartEvent 创建内容块开始事件 +func NewContentBlockStartEvent(index int, block StreamContentBlock) CanonicalStreamEvent { + idx := index + return CanonicalStreamEvent{ + Type: EventContentBlockStart, + Index: &idx, + ContentBlock: &block, + } +} + +// NewContentBlockDeltaEvent 创建内容块增量事件 +func NewContentBlockDeltaEvent(index int, delta StreamDelta) CanonicalStreamEvent { + idx := index + return CanonicalStreamEvent{ + Type: EventContentBlockDelta, + Index: &idx, + Delta: &delta, + } +} + +// NewContentBlockStopEvent 创建内容块结束事件 +func NewContentBlockStopEvent(index int) CanonicalStreamEvent { + idx := index + return CanonicalStreamEvent{ + Type: EventContentBlockStop, + Index: &idx, + } +} + +// NewMessageDeltaEvent 创建消息增量事件 +func NewMessageDeltaEvent(stopReason StopReason) CanonicalStreamEvent { + sr := stopReason + return CanonicalStreamEvent{ + Type: EventMessageDelta, + StopReason: &sr, + } +} + +// NewMessageDeltaEventWithUsage 创建带用量的消息增量事件 +func NewMessageDeltaEventWithUsage(stopReason StopReason, usage *CanonicalUsage) CanonicalStreamEvent { + sr := stopReason + return CanonicalStreamEvent{ + Type: EventMessageDelta, + StopReason: &sr, + Usage: usage, + } +} + +// NewMessageStopEvent 创建消息结束事件 +func NewMessageStopEvent() CanonicalStreamEvent { + return CanonicalStreamEvent{Type: EventMessageStop} +} + +// NewErrorEvent 创建错误事件 +func NewErrorEvent(errType, message string) CanonicalStreamEvent { + return CanonicalStreamEvent{ + Type: EventError, + Error: 
&StreamError{Type: errType, Message: message}, + } +} + +// NewPingEvent 创建心跳事件 +func NewPingEvent() CanonicalStreamEvent { + return CanonicalStreamEvent{Type: EventPing} +} diff --git a/backend/internal/conversion/canonical/types.go b/backend/internal/conversion/canonical/types.go new file mode 100644 index 0000000..8d9de7a --- /dev/null +++ b/backend/internal/conversion/canonical/types.go @@ -0,0 +1,208 @@ +package canonical + +import ( + "encoding/json" + "fmt" +) + +// MessageRole 消息角色枚举 +type MessageRole string + +const ( + RoleSystem MessageRole = "system" + RoleUser MessageRole = "user" + RoleAssistant MessageRole = "assistant" + RoleTool MessageRole = "tool" +) + +// StopReason 停止原因枚举 +type StopReason string + +const ( + StopReasonEndTurn StopReason = "end_turn" + StopReasonMaxTokens StopReason = "max_tokens" + StopReasonToolUse StopReason = "tool_use" + StopReasonStopSequence StopReason = "stop_sequence" + StopReasonContentFilter StopReason = "content_filter" + StopReasonRefusal StopReason = "refusal" +) + +// SystemBlock 系统消息块 +type SystemBlock struct { + Text string `json:"text"` +} + +// ContentBlock 使用 type 字段的 discriminated union +type ContentBlock struct { + Type string `json:"type"` + + // TextBlock + Text string `json:"text,omitempty"` + + // ToolUseBlock + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // ToolResultBlock + ToolUseID string `json:"tool_use_id,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + IsError *bool `json:"is_error,omitempty"` + + // ThinkingBlock + Thinking string `json:"thinking,omitempty"` +} + +// NewTextBlock 创建文本块 +func NewTextBlock(text string) ContentBlock { + return ContentBlock{Type: "text", Text: text} +} + +// NewToolUseBlock 创建工具调用块 +func NewToolUseBlock(id, name string, input json.RawMessage) ContentBlock { + return ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} +} + +// NewToolResultBlock 创建工具结果块 +func 
NewToolResultBlock(toolUseID string, content string, isError bool) ContentBlock { + errFlag := &isError + return ContentBlock{ + Type: "tool_result", + ToolUseID: toolUseID, + Content: json.RawMessage(fmt.Sprintf("%q", content)), + IsError: errFlag, + } +} + +// NewThinkingBlock 创建思考块 +func NewThinkingBlock(thinking string) ContentBlock { + return ContentBlock{Type: "thinking", Thinking: thinking} +} + +// CanonicalMessage 规范消息 +type CanonicalMessage struct { + Role MessageRole `json:"role"` + Content []ContentBlock `json:"content"` +} + +// CanonicalTool 规范工具定义 +type CanonicalTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` +} + +// ToolChoice 工具选择联合体 +type ToolChoice struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` +} + +// NewToolChoiceAuto 创建自动工具选择 +func NewToolChoiceAuto() *ToolChoice { + return &ToolChoice{Type: "auto"} +} + +// NewToolChoiceNone 创建无工具选择 +func NewToolChoiceNone() *ToolChoice { + return &ToolChoice{Type: "none"} +} + +// NewToolChoiceAny 创建任意工具选择 +func NewToolChoiceAny() *ToolChoice { + return &ToolChoice{Type: "any"} +} + +// NewToolChoiceNamed 创建指定工具选择 +func NewToolChoiceNamed(name string) *ToolChoice { + return &ToolChoice{Type: "tool", Name: name} +} + +// RequestParameters 请求参数 +type RequestParameters struct { + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} + +// ThinkingConfig 思考配置 +type ThinkingConfig struct { + Type string `json:"type"` + BudgetTokens *int `json:"budget_tokens,omitempty"` + Effort string `json:"effort,omitempty"` +} + +// OutputFormat 输出格式联合体 +type OutputFormat struct { + Type string 
`json:"type"` + Name string `json:"name,omitempty"` + Schema json.RawMessage `json:"schema,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// CanonicalRequest 规范请求 +type CanonicalRequest struct { + Model string `json:"model"` + System any `json:"system,omitempty"` // nil, string, or []SystemBlock + Messages []CanonicalMessage `json:"messages"` + Tools []CanonicalTool `json:"tools,omitempty"` + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` + Parameters RequestParameters `json:"parameters"` + Thinking *ThinkingConfig `json:"thinking,omitempty"` + Stream bool `json:"stream"` + UserID string `json:"user_id,omitempty"` + OutputFormat *OutputFormat `json:"output_format,omitempty"` + ParallelToolUse *bool `json:"parallel_tool_use,omitempty"` +} + +// CanonicalUsage 规范用量 +type CanonicalUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheReadTokens *int `json:"cache_read_tokens,omitempty"` + CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"` + ReasoningTokens *int `json:"reasoning_tokens,omitempty"` +} + +// CanonicalResponse 规范响应 +type CanonicalResponse struct { + ID string `json:"id"` + Model string `json:"model"` + Content []ContentBlock `json:"content"` + StopReason *StopReason `json:"stop_reason,omitempty"` + Usage CanonicalUsage `json:"usage"` +} + +// GetSystemString 获取系统消息字符串 +func (r *CanonicalRequest) GetSystemString() string { + switch v := r.System.(type) { + case string: + return v + case []SystemBlock: + var result string + for i, b := range v { + if i > 0 { + result += "\n\n" + } + result += b.Text + } + return result + case nil: + return "" + default: + return fmt.Sprintf("%v", v) + } +} + +// SetSystemString 设置系统消息字符串 +func (r *CanonicalRequest) SetSystemString(s string) { + if s == "" { + r.System = nil + } else { + r.System = s + } +} diff --git a/backend/internal/conversion/engine.go b/backend/internal/conversion/engine.go new file mode 100644 index 0000000..4e6bfa0 --- 
/dev/null +++ b/backend/internal/conversion/engine.go @@ -0,0 +1,338 @@ +package conversion + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// HTTPRequestSpec HTTP 请求规格 +type HTTPRequestSpec struct { + URL string `json:"url"` + Method string `json:"method"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +// HTTPResponseSpec HTTP 响应规格 +type HTTPResponseSpec struct { + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +// ConversionEngine 转换引擎门面 +type ConversionEngine struct { + registry AdapterRegistry + middlewareChain *MiddlewareChain +} + +// NewConversionEngine 创建转换引擎 +func NewConversionEngine(registry AdapterRegistry) *ConversionEngine { + return &ConversionEngine{ + registry: registry, + middlewareChain: NewMiddlewareChain(), + } +} + +// RegisterAdapter 注册协议适配器 +func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error { + return e.registry.Register(adapter) +} + +// GetRegistry 返回注册表(供外部使用) +func (e *ConversionEngine) GetRegistry() AdapterRegistry { + return e.registry +} + +// Use 添加中间件 +func (e *ConversionEngine) Use(mw ConversionMiddleware) { + e.middlewareChain.Use(mw) +} + +// IsPassthrough 判断是否同协议透传 +func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool { + if clientProtocol != providerProtocol { + return false + } + adapter, err := e.registry.Get(clientProtocol) + if err != nil { + return false + } + return adapter.SupportsPassthrough() +} + +// ConvertHttpRequest 转换 HTTP 请求 +func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) { + nativePath := spec.URL + + if e.IsPassthrough(clientProtocol, providerProtocol) { + providerAdapter, err := e.registry.Get(providerProtocol) + if err != nil { + return nil, err + } + return &HTTPRequestSpec{ + URL: 
provider.BaseURL + nativePath, + Method: spec.Method, + Headers: providerAdapter.BuildHeaders(provider), + Body: spec.Body, + }, nil + } + + clientAdapter, err := e.registry.Get(clientProtocol) + if err != nil { + return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err) + } + providerAdapter, err := e.registry.Get(providerProtocol) + if err != nil { + return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err) + } + + interfaceType := clientAdapter.DetectInterfaceType(nativePath) + providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) + providerHeaders := providerAdapter.BuildHeaders(provider) + providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) + if err != nil { + return nil, err + } + + return &HTTPRequestSpec{ + URL: provider.BaseURL + providerUrl, + Method: spec.Method, + Headers: providerHeaders, + Body: providerBody, + }, nil +} + +// ConvertHttpResponse 转换 HTTP 响应 +func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) { + if e.IsPassthrough(clientProtocol, providerProtocol) { + return &spec, nil + } + + clientAdapter, err := e.registry.Get(clientProtocol) + if err != nil { + return nil, err + } + providerAdapter, err := e.registry.Get(providerProtocol) + if err != nil { + return nil, err + } + + convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body) + if err != nil { + return nil, err + } + + return &HTTPResponseSpec{ + StatusCode: spec.StatusCode, + Headers: spec.Headers, + Body: convertedBody, + }, nil +} + +// CreateStreamConverter 创建流式转换器 +func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) { + if e.IsPassthrough(clientProtocol, providerProtocol) { + return NewPassthroughStreamConverter(), nil + } + + providerAdapter, err := e.registry.Get(providerProtocol) + 
if err != nil { + return nil, err + } + clientAdapter, err := e.registry.Get(clientProtocol) + if err != nil { + return nil, err + } + + ctx := ConversionContext{ + ConversionID: uuid.New().String(), + InterfaceType: InterfaceTypeChat, + Timestamp: time.Now(), + } + + return NewCanonicalStreamConverterWithMiddleware( + providerAdapter.CreateStreamDecoder(), + clientAdapter.CreateStreamEncoder(), + e.middlewareChain, + ctx, + clientProtocol, + providerProtocol, + ), nil +} + +// convertBody 转换请求体 +func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { + switch interfaceType { + case InterfaceTypeChat: + return e.convertChatBody(clientAdapter, providerAdapter, provider, body) + case InterfaceTypeModels, InterfaceTypeModelInfo: + return body, nil + case InterfaceTypeEmbeddings: + if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) { + return body, nil + } + return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body) + case InterfaceTypeRerank: + if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) { + return body, nil + } + return e.convertRerankBody(clientAdapter, providerAdapter, provider, body) + default: + return body, nil + } +} + +// convertResponseBody 转换响应体 +func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + switch interfaceType { + case InterfaceTypeChat: + return e.convertChatResponseBody(clientAdapter, providerAdapter, body) + case InterfaceTypeModels: + if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) { + return body, nil + } + return e.convertModelsResponseBody(clientAdapter, providerAdapter, body) + case 
InterfaceTypeModelInfo: + if !clientAdapter.SupportsInterface(InterfaceTypeModelInfo) || !providerAdapter.SupportsInterface(InterfaceTypeModelInfo) { + return body, nil + } + return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body) + case InterfaceTypeEmbeddings: + if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) { + return body, nil + } + return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body) + case InterfaceTypeRerank: + if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) { + return body, nil + } + return e.convertRerankResponseBody(clientAdapter, providerAdapter, body) + default: + return body, nil + } +} + +func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { + canonicalReq, err := clientAdapter.DecodeRequest(body) + if err != nil { + return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err) + } + + ctx := NewConversionContext(InterfaceTypeChat) + canonicalReq, err = e.middlewareChain.Apply(canonicalReq, clientAdapter.ProtocolName(), providerAdapter.ProtocolName(), ctx) + if err != nil { + return nil, err + } + + encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider) + if err != nil { + return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(err) + } + return encoded, nil +} + +func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + canonicalResp, err := providerAdapter.DecodeResponse(body) + if err != nil { + return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err) + } + encoded, err := clientAdapter.EncodeResponse(canonicalResp) + if err != nil { + return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err) + } + 
return encoded, nil +} + +func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + models, err := providerAdapter.DecodeModelsResponse(body) + if err != nil { + zap.L().Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) + return body, nil + } + encoded, err := clientAdapter.EncodeModelsResponse(models) + if err != nil { + zap.L().Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) + return body, nil + } + return encoded, nil +} + +func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + info, err := providerAdapter.DecodeModelInfoResponse(body) + if err != nil { + zap.L().Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) + return body, nil + } + encoded, err := clientAdapter.EncodeModelInfoResponse(info) + if err != nil { + zap.L().Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) + return body, nil + } + return encoded, nil +} + +func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { + req, err := clientAdapter.DecodeEmbeddingRequest(body) + if err != nil { + zap.L().Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) + return body, nil + } + return providerAdapter.EncodeEmbeddingRequest(req, provider) +} + +func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + resp, err := providerAdapter.DecodeEmbeddingResponse(body) + if err != nil { + zap.L().Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) + return body, nil + } + return clientAdapter.EncodeEmbeddingResponse(resp) +} + +func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { + req, err := 
clientAdapter.DecodeRerankRequest(body) + if err != nil { + zap.L().Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) + return body, nil + } + return providerAdapter.EncodeRerankRequest(req, provider) +} + +func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { + resp, err := providerAdapter.DecodeRerankResponse(body) + if err != nil { + return body, nil + } + return clientAdapter.EncodeRerankResponse(resp) +} + +// DetectInterfaceType 检测接口类型 +func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) { + adapter, err := e.registry.Get(clientProtocol) + if err != nil { + return InterfaceTypePassthrough, err + } + return adapter.DetectInterfaceType(nativePath), nil +} + +// EncodeError 使用客户端适配器编码错误 +func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) { + adapter, adapterErr := e.registry.Get(clientProtocol) + if adapterErr != nil { + fallback := map[string]any{ + "error": map[string]string{ + "message": err.Error(), + "type": "internal_error", + }, + } + body, _ := json.Marshal(fallback) + return body, 500, nil + } + body, statusCode := adapter.EncodeError(err) + return body, statusCode, nil +} diff --git a/backend/internal/conversion/engine_test.go b/backend/internal/conversion/engine_test.go new file mode 100644 index 0000000..3f13c07 --- /dev/null +++ b/backend/internal/conversion/engine_test.go @@ -0,0 +1,366 @@ +package conversion + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockProtocolAdapter 模拟协议适配器 +type mockProtocolAdapter struct { + protocolName string + passthrough bool + ifaceType InterfaceType + supportsIface map[InterfaceType]bool + decodeReqFn func([]byte) (*canonical.CanonicalRequest, error) + encodeReqFn 
func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error) + decodeRespFn func([]byte) (*canonical.CanonicalResponse, error) + encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error) + streamDecoderFn func() StreamDecoder + streamEncoderFn func() StreamEncoder +} + +func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter { + return &mockProtocolAdapter{ + protocolName: name, + passthrough: passthrough, + ifaceType: InterfaceTypeChat, + supportsIface: map[InterfaceType]bool{}, + } +} + +func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName } +func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" } +func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough } + +func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType { + return m.ifaceType +} + +func (m *mockProtocolAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string { + return nativePath +} + +func (m *mockProtocolAdapter) BuildHeaders(provider *TargetProvider) map[string]string { + return map[string]string{"Authorization": "Bearer " + provider.APIKey} +} + +func (m *mockProtocolAdapter) SupportsInterface(interfaceType InterfaceType) bool { + if v, ok := m.supportsIface[interfaceType]; ok { + return v + } + return interfaceType == InterfaceTypeChat +} + +func (m *mockProtocolAdapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) { + if m.decodeReqFn != nil { + return m.decodeReqFn(raw) + } + req := &canonical.CanonicalRequest{} + _ = json.Unmarshal(raw, req) + return req, nil +} + +func (m *mockProtocolAdapter) EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error) { + if m.encodeReqFn != nil { + return m.encodeReqFn(req, provider) + } + return json.Marshal(req) +} + +func (m *mockProtocolAdapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) { + if m.decodeRespFn != nil { + return m.decodeRespFn(raw) + } + 
resp := &canonical.CanonicalResponse{} + _ = json.Unmarshal(raw, resp) + return resp, nil +} + +func (m *mockProtocolAdapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { + if m.encodeRespFn != nil { + return m.encodeRespFn(resp) + } + return json.Marshal(resp) +} + +func (m *mockProtocolAdapter) CreateStreamDecoder() StreamDecoder { + if m.streamDecoderFn != nil { + return m.streamDecoderFn() + } + return &noopStreamDecoder{} +} + +func (m *mockProtocolAdapter) CreateStreamEncoder() StreamEncoder { + if m.streamEncoderFn != nil { + return m.streamEncoderFn() + } + return &noopStreamEncoder{} +} + +func (m *mockProtocolAdapter) EncodeError(err *ConversionError) ([]byte, int) { + return []byte(`{"error":"mock"}`), 400 +} + +func (m *mockProtocolAdapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) { + return &canonical.CanonicalModelList{}, nil +} + +func (m *mockProtocolAdapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) { + return json.Marshal(list) +} + +func (m *mockProtocolAdapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) { + return &canonical.CanonicalModelInfo{}, nil +} + +func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) { + return json.Marshal(info) +} + +func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) { + return &canonical.CanonicalEmbeddingRequest{}, nil +} + +func (m *mockProtocolAdapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error) { + return json.Marshal(req) +} + +func (m *mockProtocolAdapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) { + return &canonical.CanonicalEmbeddingResponse{}, nil +} + +func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) { + return 
json.Marshal(resp) +} + +func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) { + return &canonical.CanonicalRerankRequest{}, nil +} + +func (m *mockProtocolAdapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) { + return json.Marshal(req) +} + +func (m *mockProtocolAdapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) { + return &canonical.CanonicalRerankResponse{}, nil +} + +func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { + return json.Marshal(resp) +} + +// noopStreamDecoder 空流式解码器 +type noopStreamDecoder struct{} + +func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil } +func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil } + +// noopStreamEncoder 空流式编码器 +type noopStreamEncoder struct{} + +func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil } +func (e *noopStreamEncoder) Flush() [][]byte { return nil } + +// ============ 测试用例 ============ + +func TestNewConversionEngine(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + assert.NotNil(t, engine) + assert.Equal(t, registry, engine.GetRegistry()) +} + +func TestRegisterAdapter(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + + adapter := newMockAdapter("test-proto", true) + err := engine.RegisterAdapter(adapter) + require.NoError(t, err) + + protocols := registry.ListProtocols() + assert.Contains(t, protocols, "test-proto") +} + +func TestIsPassthrough_SameProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + adapter := newMockAdapter("openai", true) + _ = engine.RegisterAdapter(adapter) + + assert.True(t, engine.IsPassthrough("openai", "openai")) +} + 
+func TestIsPassthrough_DifferentProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("openai", true)) + _ = engine.RegisterAdapter(newMockAdapter("anthropic", true)) + + assert.False(t, engine.IsPassthrough("openai", "anthropic")) +} + +func TestIsPassthrough_NoPassthrough(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("custom", false)) + + assert.False(t, engine.IsPassthrough("custom", "custom")) +} + +func TestDetectInterfaceType(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + adapter := newMockAdapter("test", true) + adapter.ifaceType = InterfaceTypeChat + _ = engine.RegisterAdapter(adapter) + + ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test") + require.NoError(t, err) + assert.Equal(t, InterfaceTypeChat, ifaceType) +} + +func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + + _, err := engine.DetectInterfaceType("/v1/chat", "nonexistent") + assert.Error(t, err) +} + +func TestConvertHttpRequest_Passthrough(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("openai", true)) + + provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4") + spec := HTTPRequestSpec{ + URL: "/v1/chat/completions", + Method: "POST", + Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`), + } + + result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider) + require.NoError(t, err) + assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL) + assert.Equal(t, spec.Body, result.Body) +} + +func TestConvertHttpRequest_CrossProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := 
NewConversionEngine(registry) + + clientAdapter := newMockAdapter("client-proto", false) + clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) { + return &canonical.CanonicalRequest{ + Model: "test-model", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + }, nil + } + _ = engine.RegisterAdapter(clientAdapter) + + providerAdapter := newMockAdapter("provider-proto", false) + providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) { + return json.Marshal(map[string]any{"model": p.ModelName}) + } + _ = engine.RegisterAdapter(providerAdapter) + + provider := NewTargetProvider("https://example.com", "key", "my-model") + spec := HTTPRequestSpec{ + URL: "/v1/chat", + Method: "POST", + Body: []byte(`{"model":"test"}`), + } + + result, err := engine.ConvertHttpRequest(spec, "client-proto", "provider-proto", provider) + require.NoError(t, err) + assert.Contains(t, result.URL, "https://example.com") + assert.NotNil(t, result.Body) +} + +func TestConvertHttpResponse_Passthrough(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("openai", true)) + + spec := HTTPResponseSpec{ + StatusCode: 200, + Body: []byte(`{"id":"123"}`), + } + + result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, 200, result.StatusCode) + assert.Equal(t, spec.Body, result.Body) +} + +func TestCreateStreamConverter_Passthrough(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("openai", true)) + + converter, err := engine.CreateStreamConverter("openai", "openai") + require.NoError(t, err) + _, ok := converter.(*PassthroughStreamConverter) + assert.True(t, ok) +} + +func 
TestCreateStreamConverter_Canonical(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("client", false)) + _ = engine.RegisterAdapter(newMockAdapter("provider", false)) + + converter, err := engine.CreateStreamConverter("client", "provider") + require.NoError(t, err) + _, ok := converter.(*CanonicalStreamConverter) + assert.True(t, ok) +} + +func TestEncodeError(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + _ = engine.RegisterAdapter(newMockAdapter("openai", true)) + + convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") + body, statusCode, err := engine.EncodeError(convErr, "openai") + require.NoError(t, err) + assert.Equal(t, 400, statusCode) + assert.NotNil(t, body) +} + +func TestEncodeError_NonExistentProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry) + + convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") + body, statusCode, err := engine.EncodeError(convErr, "nonexistent") + require.NoError(t, err) + assert.Equal(t, 500, statusCode) + assert.Contains(t, string(body), "测试错误") +} + +func TestRegistry_DuplicateRegistration(t *testing.T) { + registry := NewMemoryRegistry() + adapter := newMockAdapter("openai", true) + + err := registry.Register(adapter) + require.NoError(t, err) + + err = registry.Register(adapter) + assert.Error(t, err) + assert.Contains(t, err.Error(), "适配器已注册") +} + +func TestRegistry_GetNonExistent(t *testing.T) { + registry := NewMemoryRegistry() + + _, err := registry.Get("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "未找到适配器") +} diff --git a/backend/internal/conversion/errors.go b/backend/internal/conversion/errors.go new file mode 100644 index 0000000..9885532 --- /dev/null +++ b/backend/internal/conversion/errors.go @@ -0,0 +1,83 @@ +package conversion + +import "fmt" + +// ErrorCode 错误码枚举 +type ErrorCode string + 
+const ( + ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT" + ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD" + ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE" + ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE" + ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR" + ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR" + ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR" + ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR" + ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION" + ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE" + ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED" +) + +// ConversionError 协议转换错误 +type ConversionError struct { + Code ErrorCode + Message string + ClientProtocol string + ProviderProtocol string + InterfaceType string + Details map[string]any + Cause error +} + +// NewConversionError 创建转换错误 +func NewConversionError(code ErrorCode, message string) *ConversionError { + return &ConversionError{ + Code: code, + Message: message, + Details: make(map[string]any), + } +} + +// WithClientProtocol 设置客户端协议 +func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError { + e.ClientProtocol = protocol + return e +} + +// WithProviderProtocol 设置服务端协议 +func (e *ConversionError) WithProviderProtocol(protocol string) *ConversionError { + e.ProviderProtocol = protocol + return e +} + +// WithInterfaceType 设置接口类型 +func (e *ConversionError) WithInterfaceType(ifaceType string) *ConversionError { + e.InterfaceType = ifaceType + return e +} + +// WithDetail 添加详情 +func (e *ConversionError) WithDetail(key string, value any) *ConversionError { + e.Details[key] = value + return e +} + +// WithCause 设置原因 +func (e *ConversionError) WithCause(cause error) *ConversionError { + e.Cause = cause + return e +} + +// Error 实现 error 接口 +func (e *ConversionError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("[%s] %s: %v", 
e.Code, e.Message, e.Cause) + } + return fmt.Sprintf("[%s] %s", e.Code, e.Message) +} + +// Unwrap 支持 errors.Is/As +func (e *ConversionError) Unwrap() error { + return e.Cause +} diff --git a/backend/internal/conversion/errors_test.go b/backend/internal/conversion/errors_test.go new file mode 100644 index 0000000..4f313a4 --- /dev/null +++ b/backend/internal/conversion/errors_test.go @@ -0,0 +1,45 @@ +package conversion + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConversionError_Builder(t *testing.T) { + cause := errors.New("原始错误") + err := NewConversionError(ErrorCodeInvalidInput, "输入无效"). + WithClientProtocol("openai"). + WithDetail("field", "model"). + WithCause(cause) + + assert.Equal(t, ErrorCodeInvalidInput, err.Code) + assert.Equal(t, "openai", err.ClientProtocol) + assert.Equal(t, "输入无效", err.Message) + assert.Equal(t, "model", err.Details["field"]) + assert.Equal(t, cause, err.Cause) +} + +func TestConversionError_Unwrap(t *testing.T) { + cause := errors.New("根本原因") + err := NewConversionError(ErrorCodeJSONParseError, "解析失败").WithCause(cause) + + unwrapped := err.Unwrap() + assert.Equal(t, cause, unwrapped) +} + +func TestConversionError_Error_WithCause(t *testing.T) { + err := NewConversionError(ErrorCodeInvalidInput, "输入无效").WithCause(errors.New("原因")) + msg := err.Error() + assert.Contains(t, msg, "INVALID_INPUT") + assert.Contains(t, msg, "输入无效") + assert.Contains(t, msg, "原因") +} + +func TestConversionError_Error_WithoutCause(t *testing.T) { + err := NewConversionError(ErrorCodeInvalidInput, "输入无效") + msg := err.Error() + assert.Contains(t, msg, "INVALID_INPUT") + assert.Contains(t, msg, "输入无效") +} diff --git a/backend/internal/conversion/interface.go b/backend/internal/conversion/interface.go new file mode 100644 index 0000000..45d775b --- /dev/null +++ b/backend/internal/conversion/interface.go @@ -0,0 +1,13 @@ +package conversion + +// InterfaceType 接口类型枚举 +type InterfaceType string + +const ( + 
InterfaceTypeChat InterfaceType = "CHAT" + InterfaceTypeModels InterfaceType = "MODELS" + InterfaceTypeModelInfo InterfaceType = "MODEL_INFO" + InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS" + InterfaceTypeRerank InterfaceType = "RERANK" + InterfaceTypePassthrough InterfaceType = "PASSTHROUGH" +) diff --git a/backend/internal/conversion/middleware.go b/backend/internal/conversion/middleware.go new file mode 100644 index 0000000..3e190c9 --- /dev/null +++ b/backend/internal/conversion/middleware.go @@ -0,0 +1,76 @@ +package conversion + +import ( + "time" + + "nex/backend/internal/conversion/canonical" + + "github.com/google/uuid" +) + +// ConversionMiddleware 转换中间件接口 +type ConversionMiddleware interface { + Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) + InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) +} + +// ConversionContext 转换上下文 +type ConversionContext struct { + ConversionID string + InterfaceType InterfaceType + Timestamp time.Time + Metadata map[string]any +} + +// NewConversionContext 创建转换上下文 +func NewConversionContext(ifaceType InterfaceType) *ConversionContext { + return &ConversionContext{ + ConversionID: uuid.New().String(), + InterfaceType: ifaceType, + Timestamp: time.Now().UTC(), + Metadata: make(map[string]any), + } +} + +// MiddlewareChain 中间件链 +type MiddlewareChain struct { + middlewares []ConversionMiddleware +} + +// NewMiddlewareChain 创建中间件链 +func NewMiddlewareChain() *MiddlewareChain { + return &MiddlewareChain{ + middlewares: make([]ConversionMiddleware, 0), + } +} + +// Use 添加中间件 +func (c *MiddlewareChain) Use(mw ConversionMiddleware) { + c.middlewares = append(c.middlewares, mw) +} + +// Apply 对请求按顺序执行所有中间件 +func (c *MiddlewareChain) Apply(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx 
*ConversionContext) (*canonical.CanonicalRequest, error) { + result := req + for _, mw := range c.middlewares { + var err error + result, err = mw.Intercept(result, clientProtocol, providerProtocol, ctx) + if err != nil { + return nil, err + } + } + return result, nil +} + +// ApplyStreamEvent 对流式事件按顺序执行所有中间件 +func (c *MiddlewareChain) ApplyStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) { + result := event + for _, mw := range c.middlewares { + var err error + result, err = mw.InterceptStreamEvent(result, clientProtocol, providerProtocol, ctx) + if err != nil { + return nil, err + } + } + return result, nil +} diff --git a/backend/internal/conversion/middleware_test.go b/backend/internal/conversion/middleware_test.go new file mode 100644 index 0000000..9a8d38a --- /dev/null +++ b/backend/internal/conversion/middleware_test.go @@ -0,0 +1,85 @@ +package conversion + +import ( + "errors" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" +) + +// recordingMiddleware 记录调用顺序的中间件 +type recordingMiddleware struct { + name string + records *[]string + err error +} + +func (m *recordingMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) { + *m.records = append(*m.records, m.name) + if m.err != nil { + return nil, m.err + } + return req, nil +} + +func (m *recordingMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) { + *m.records = append(*m.records, "stream:"+m.name) + if m.err != nil { + return nil, m.err + } + return event, nil +} + +func TestMiddlewareChain_Empty(t *testing.T) { + chain := NewMiddlewareChain() + req := &canonical.CanonicalRequest{Model: "test"} + ctx := 
NewConversionContext(InterfaceTypeChat) + + result, err := chain.Apply(req, "a", "b", ctx) + assert.NoError(t, err) + assert.Equal(t, "test", result.Model) +} + +func TestMiddlewareChain_Order(t *testing.T) { + var records []string + chain := NewMiddlewareChain() + chain.Use(&recordingMiddleware{name: "first", records: &records}) + chain.Use(&recordingMiddleware{name: "second", records: &records}) + chain.Use(&recordingMiddleware{name: "third", records: &records}) + + req := &canonical.CanonicalRequest{Model: "test"} + ctx := NewConversionContext(InterfaceTypeChat) + _, err := chain.Apply(req, "a", "b", ctx) + assert.NoError(t, err) + assert.Equal(t, []string{"first", "second", "third"}, records) +} + +func TestMiddlewareChain_ErrorInterrupt(t *testing.T) { + var records []string + chain := NewMiddlewareChain() + chain.Use(&recordingMiddleware{name: "first", records: &records}) + chain.Use(&recordingMiddleware{name: "second", records: &records, err: errors.New("中断")}) + chain.Use(&recordingMiddleware{name: "third", records: &records}) + + req := &canonical.CanonicalRequest{Model: "test"} + ctx := NewConversionContext(InterfaceTypeChat) + _, err := chain.Apply(req, "a", "b", ctx) + assert.Error(t, err) + assert.Equal(t, "中断", err.Error()) + assert.Equal(t, []string{"first", "second"}, records) +} + +func TestMiddlewareChain_ApplyStreamEvent(t *testing.T) { + var records []string + chain := NewMiddlewareChain() + chain.Use(&recordingMiddleware{name: "mw1", records: &records}) + + event := canonical.NewMessageStartEvent("id", "model") + ctx := NewConversionContext(InterfaceTypeChat) + result, err := chain.ApplyStreamEvent(&event, "a", "b", ctx) + assert.NoError(t, err) + assert.Equal(t, canonical.EventMessageStart, result.Type) + assert.Equal(t, []string{"stream:mw1"}, records) +} diff --git a/backend/internal/conversion/openai/adapter.go b/backend/internal/conversion/openai/adapter.go new file mode 100644 index 0000000..fdfb5e7 --- /dev/null +++ 
b/backend/internal/conversion/openai/adapter.go @@ -0,0 +1,211 @@ +package openai + +import ( + "encoding/json" + "regexp" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" +) + +// Adapter OpenAI 协议适配器 +type Adapter struct{} + +// NewAdapter 创建 OpenAI 适配器 +func NewAdapter() *Adapter { + return &Adapter{} +} + +var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`) + +// ProtocolName 返回协议名称 +func (a *Adapter) ProtocolName() string { return "openai" } + +// ProtocolVersion 返回协议版本 +func (a *Adapter) ProtocolVersion() string { return "" } + +// SupportsPassthrough 支持同协议透传 +func (a *Adapter) SupportsPassthrough() bool { return true } + +// DetectInterfaceType 根据路径检测接口类型 +func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType { + switch { + case nativePath == "/v1/chat/completions": + return conversion.InterfaceTypeChat + case nativePath == "/v1/models": + return conversion.InterfaceTypeModels + case modelInfoRegex.MatchString(nativePath): + return conversion.InterfaceTypeModelInfo + case nativePath == "/v1/embeddings": + return conversion.InterfaceTypeEmbeddings + case nativePath == "/v1/rerank": + return conversion.InterfaceTypeRerank + default: + return conversion.InterfaceTypePassthrough + } +} + +// BuildUrl 根据接口类型构建 URL +func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string { + switch interfaceType { + case conversion.InterfaceTypeChat: + return "/v1/chat/completions" + case conversion.InterfaceTypeModels: + return "/v1/models" + case conversion.InterfaceTypeEmbeddings: + return "/v1/embeddings" + case conversion.InterfaceTypeRerank: + return "/v1/rerank" + default: + return nativePath + } +} + +// BuildHeaders 构建请求头 +func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string { + headers := map[string]string{ + "Authorization": "Bearer " + provider.APIKey, + "Content-Type": "application/json", + } + if org, ok := 
provider.AdapterConfig["organization"].(string); ok && org != "" { + headers["OpenAI-Organization"] = org + } + return headers +} + +// SupportsInterface 检查是否支持接口类型 +func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool { + switch interfaceType { + case conversion.InterfaceTypeChat, + conversion.InterfaceTypeModels, + conversion.InterfaceTypeModelInfo, + conversion.InterfaceTypeEmbeddings, + conversion.InterfaceTypeRerank: + return true + default: + return false + } +} + +// DecodeRequest 解码请求 +func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) { + return decodeRequest(raw) +} + +// EncodeRequest 编码请求 +func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) { + return encodeRequest(req, provider) +} + +// DecodeResponse 解码响应 +func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) { + return decodeResponse(raw) +} + +// EncodeResponse 编码响应 +func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { + return encodeResponse(resp) +} + +// CreateStreamDecoder 创建流式解码器 +func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder { + return NewStreamDecoder() +} + +// CreateStreamEncoder 创建流式编码器 +func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder { + return NewStreamEncoder() +} + +// EncodeError 编码错误 +func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) { + errType := mapErrorCode(err.Code) + statusCode := 500 + + errMsg := ErrorResponse{ + Error: ErrorDetail{ + Message: err.Message, + Type: errType, + Param: nil, + Code: string(err.Code), + }, + } + body, _ := json.Marshal(errMsg) + return body, statusCode +} + +// mapErrorCode 映射错误码到 OpenAI 错误类型 +func mapErrorCode(code conversion.ErrorCode) string { + switch code { + case conversion.ErrorCodeInvalidInput, + conversion.ErrorCodeMissingRequiredField, + conversion.ErrorCodeIncompatibleFeature, + 
conversion.ErrorCodeToolCallParseError, + conversion.ErrorCodeJSONParseError, + conversion.ErrorCodeProtocolConstraint, + conversion.ErrorCodeFieldMappingFailure: + return "invalid_request_error" + default: + return "server_error" + } +} + +// DecodeModelsResponse 解码模型列表响应 +func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) { + return decodeModelsResponse(raw) +} + +// EncodeModelsResponse 编码模型列表响应 +func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) { + return encodeModelsResponse(list) +} + +// DecodeModelInfoResponse 解码模型详情响应 +func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) { + return decodeModelInfoResponse(raw) +} + +// EncodeModelInfoResponse 编码模型详情响应 +func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) { + return encodeModelInfoResponse(info) +} + +// DecodeEmbeddingRequest 解码嵌入请求 +func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) { + return decodeEmbeddingRequest(raw) +} + +// EncodeEmbeddingRequest 编码嵌入请求 +func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) { + return encodeEmbeddingRequest(req, provider) +} + +// DecodeEmbeddingResponse 解码嵌入响应 +func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) { + return decodeEmbeddingResponse(raw) +} + +// EncodeEmbeddingResponse 编码嵌入响应 +func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) { + return encodeEmbeddingResponse(resp) +} + +// DecodeRerankRequest 解码重排序请求 +func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) { + return decodeRerankRequest(raw) +} + +// EncodeRerankRequest 编码重排序请求 +func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider 
*conversion.TargetProvider) ([]byte, error) { + return encodeRerankRequest(req, provider) +} + +// DecodeRerankResponse 解码重排序响应 +func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) { + return decodeRerankResponse(raw) +} + +// EncodeRerankResponse 编码重排序响应 +func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { + return encodeRerankResponse(resp) +} diff --git a/backend/internal/conversion/openai/adapter_test.go b/backend/internal/conversion/openai/adapter_test.go new file mode 100644 index 0000000..c220dbd --- /dev/null +++ b/backend/internal/conversion/openai/adapter_test.go @@ -0,0 +1,139 @@ +package openai + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdapter_ProtocolName(t *testing.T) { + a := NewAdapter() + assert.Equal(t, "openai", a.ProtocolName()) +} + +func TestAdapter_SupportsPassthrough(t *testing.T) { + a := NewAdapter() + assert.True(t, a.SupportsPassthrough()) +} + +func TestAdapter_DetectInterfaceType(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + path string + expected conversion.InterfaceType + }{ + {"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat}, + {"模型列表", "/v1/models", conversion.InterfaceTypeModels}, + {"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo}, + {"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings}, + {"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank}, + {"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.DetectInterfaceType(tt.path) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdapter_BuildUrl(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + nativePath string + interfaceType conversion.InterfaceType + expected 
string + }{ + {"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"}, + {"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"}, + {"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"}, + {"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"}, + {"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.BuildUrl(tt.nativePath, tt.interfaceType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAdapter_BuildHeaders(t *testing.T) { + a := NewAdapter() + + t.Run("基本头", func(t *testing.T) { + provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4") + headers := a.BuildHeaders(provider) + assert.Equal(t, "Bearer sk-test123", headers["Authorization"]) + assert.Equal(t, "application/json", headers["Content-Type"]) + _, hasOrg := headers["OpenAI-Organization"] + assert.False(t, hasOrg) + }) + + t.Run("带组织", func(t *testing.T) { + provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4") + provider.AdapterConfig["organization"] = "org-abc" + headers := a.BuildHeaders(provider) + assert.Equal(t, "org-abc", headers["OpenAI-Organization"]) + }) +} + +func TestAdapter_SupportsInterface(t *testing.T) { + a := NewAdapter() + + tests := []struct { + name string + interfaceType conversion.InterfaceType + expected bool + }{ + {"聊天", conversion.InterfaceTypeChat, true}, + {"模型", conversion.InterfaceTypeModels, true}, + {"模型详情", conversion.InterfaceTypeModelInfo, true}, + {"嵌入", conversion.InterfaceTypeEmbeddings, true}, + {"重排序", conversion.InterfaceTypeRerank, true}, + {"透传", conversion.InterfaceTypePassthrough, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := a.SupportsInterface(tt.interfaceType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func 
TestAdapter_EncodeError_InvalidInput(t *testing.T) { + a := NewAdapter() + convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") + + body, statusCode := a.EncodeError(convErr) + require.Equal(t, 500, statusCode) + + var resp ErrorResponse + require.NoError(t, json.Unmarshal(body, &resp)) + assert.Equal(t, "参数无效", resp.Error.Message) + assert.Equal(t, "invalid_request_error", resp.Error.Type) +} + +func TestAdapter_EncodeError_ServerError(t *testing.T) { + a := NewAdapter() + convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误") + + body, statusCode := a.EncodeError(convErr) + require.Equal(t, 500, statusCode) + + var resp ErrorResponse + require.NoError(t, json.Unmarshal(body, &resp)) + assert.Equal(t, "server_error", resp.Error.Type) + assert.Equal(t, "流状态错误", resp.Error.Message) +} diff --git a/backend/internal/conversion/openai/decoder.go b/backend/internal/conversion/openai/decoder.go new file mode 100644 index 0000000..0738cb9 --- /dev/null +++ b/backend/internal/conversion/openai/decoder.go @@ -0,0 +1,669 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" + "sync/atomic" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" +) + +// decodeRequest 将 OpenAI 请求解码为 Canonical 请求 +func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) { + var req ChatCompletionRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 请求失败").WithCause(err) + } + + if strings.TrimSpace(req.Model) == "" { + return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空") + } + if len(req.Messages) == 0 { + return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空") + } + + // 废弃字段兼容 + decodeDeprecatedFields(&req) + + system, messages := decodeSystemPrompt(req.Messages) + + var canonicalMsgs 
[]canonical.CanonicalMessage + for _, msg := range messages { + decoded, err := decodeMessage(msg) + if err != nil { + return nil, err + } + canonicalMsgs = append(canonicalMsgs, decoded...) + } + + tools := decodeTools(req.Tools) + toolChoice := decodeToolChoice(req.ToolChoice) + params := decodeParameters(&req) + outputFormat := decodeOutputFormat(req.ResponseFormat) + thinking := decodeThinking(req.ReasoningEffort) + + var parallelToolUse *bool + if req.ParallelToolCalls != nil { + parallelToolUse = req.ParallelToolCalls + } + + return &canonical.CanonicalRequest{ + Model: req.Model, + System: system, + Messages: canonicalMsgs, + Tools: tools, + ToolChoice: toolChoice, + Parameters: params, + Thinking: thinking, + Stream: req.Stream, + UserID: req.User, + OutputFormat: outputFormat, + ParallelToolUse: parallelToolUse, + }, nil +} + +// decodeSystemPrompt 提取 system 和 developer 消息 +func decodeSystemPrompt(messages []Message) (any, []Message) { + var systemParts []string + var remaining []Message + + for _, msg := range messages { + if msg.Role == "system" || msg.Role == "developer" { + text := extractText(msg.Content) + if text != "" { + systemParts = append(systemParts, text) + } + } else { + remaining = append(remaining, msg) + } + } + + if len(systemParts) == 0 { + return nil, remaining + } + return strings.Join(systemParts, "\n\n"), remaining +} + +// extractText 从 content 提取文本 +func extractText(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var parts []string + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if t, ok := m["type"].(string); ok && t == "text" { + if text, ok := m["text"].(string); ok { + parts = append(parts, text) + } + } + } + } + return strings.Join(parts, "") + case nil: + return "" + default: + return fmt.Sprintf("%v", v) + } +} + +// decodeMessage 解码 OpenAI 消息 +func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) { + switch msg.Role { + case "user": + 
blocks := decodeUserContent(msg.Content) + return []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: blocks}}, nil + + case "assistant": + var blocks []canonical.ContentBlock + // 处理 content + if msg.Content != nil { + switch v := msg.Content.(type) { + case string: + if v != "" { + blocks = append(blocks, canonical.NewTextBlock(v)) + } + default: + parts := decodeContentParts(msg.Content) + for _, p := range parts { + if p.Type == "text" { + blocks = append(blocks, canonical.NewTextBlock(p.Text)) + } else if p.Type == "refusal" { + blocks = append(blocks, canonical.NewTextBlock(p.Refusal)) + } + } + } + } + // refusal 顶层字段 + if msg.Refusal != "" { + blocks = append(blocks, canonical.NewTextBlock(msg.Refusal)) + } + // reasoning_content 非标准字段 + if msg.ReasoningContent != "" { + blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent)) + } + // tool_calls + for _, tc := range msg.ToolCalls { + var input json.RawMessage + if tc.Type == "custom" && tc.Custom != nil { + input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input)) + } else if tc.Function != nil { + parsed := json.RawMessage(tc.Function.Arguments) + if !json.Valid(parsed) { + parsed = json.RawMessage("{}") + } + input = parsed + } else { + input = json.RawMessage("{}") + } + name := "" + if tc.Function != nil { + name = tc.Function.Name + } else if tc.Custom != nil { + name = tc.Custom.Name + } + blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input)) + } + // 已废弃 function_call + if msg.FunctionCall != nil { + input := json.RawMessage(msg.FunctionCall.Arguments) + if !json.Valid(input) { + input = json.RawMessage("{}") + } + blocks = append(blocks, canonical.NewToolUseBlock(generateID(), msg.FunctionCall.Name, input)) + } + if len(blocks) == 0 { + blocks = append(blocks, canonical.NewTextBlock("")) + } + return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil + + case "tool": + content := extractText(msg.Content) + isErr := 
false + block := canonical.ContentBlock{ + Type: "tool_result", + ToolUseID: msg.ToolCallID, + Content: json.RawMessage(fmt.Sprintf("%q", content)), + IsError: &isErr, + } + return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil + + case "function": + content := extractText(msg.Content) + isErr := false + block := canonical.ContentBlock{ + Type: "tool_result", + ToolUseID: msg.Name, + Content: json.RawMessage(fmt.Sprintf("%q", content)), + IsError: &isErr, + } + return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil + } + + return nil, nil +} + +// decodeUserContent 解码用户内容 +func decodeUserContent(content any) []canonical.ContentBlock { + switch v := content.(type) { + case string: + return []canonical.ContentBlock{canonical.NewTextBlock(v)} + case []any: + var blocks []canonical.ContentBlock + for _, item := range v { + if m, ok := item.(map[string]any); ok { + t, _ := m["type"].(string) + switch t { + case "text": + text, _ := m["text"].(string) + blocks = append(blocks, canonical.NewTextBlock(text)) + case "image_url": + blocks = append(blocks, canonical.ContentBlock{Type: "image"}) + case "input_audio": + blocks = append(blocks, canonical.ContentBlock{Type: "audio"}) + case "file": + blocks = append(blocks, canonical.ContentBlock{Type: "file"}) + } + } + } + if len(blocks) > 0 { + return blocks + } + return []canonical.ContentBlock{canonical.NewTextBlock("")} + case nil: + return []canonical.ContentBlock{canonical.NewTextBlock("")} + default: + return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} + } +} + +// contentPart 内容部分 +type contentPart struct { + Type string + Text string + Refusal string +} + +// decodeContentParts 解码内容部分 +func decodeContentParts(content any) []contentPart { + parts, ok := content.([]any) + if !ok { + return nil + } + var result []contentPart + for _, item := range parts { + if m, ok := 
item.(map[string]any); ok {
			t, _ := m["type"].(string)
			switch t {
			case "text":
				text, _ := m["text"].(string)
				result = append(result, contentPart{Type: "text", Text: text})
			case "refusal":
				refusal, _ := m["refusal"].(string)
				result = append(result, contentPart{Type: "refusal", Refusal: refusal})
			}
		}
	}
	return result
}

// decodeTools decodes OpenAI tool definitions into canonical tools.
// Only "function" tools are supported; other types are skipped.
func decodeTools(tools []Tool) []canonical.CanonicalTool {
	if len(tools) == 0 {
		return nil
	}
	var result []canonical.CanonicalTool
	for _, tool := range tools {
		if tool.Type == "function" && tool.Function != nil {
			result = append(result, canonical.CanonicalTool{
				Name:        tool.Function.Name,
				Description: tool.Function.Description,
				InputSchema: tool.Function.Parameters,
			})
		}
	}
	if len(result) == 0 {
		return nil
	}
	return result
}

// decodeToolChoice decodes the OpenAI tool_choice value (string or object)
// into a canonical ToolChoice. Returns nil for unrecognized shapes.
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
	if toolChoice == nil {
		return nil
	}
	switch v := toolChoice.(type) {
	case string:
		switch v {
		case "auto":
			return canonical.NewToolChoiceAuto()
		case "none":
			return canonical.NewToolChoiceNone()
		case "required":
			// OpenAI "required" maps to canonical "any".
			return canonical.NewToolChoiceAny()
		}
	case map[string]any:
		t, _ := v["type"].(string)
		switch t {
		case "function":
			if fn, ok := v["function"].(map[string]any); ok {
				name, _ := fn["name"].(string)
				return canonical.NewToolChoiceNamed(name)
			}
		case "custom":
			if custom, ok := v["custom"].(map[string]any); ok {
				name, _ := custom["name"].(string)
				return canonical.NewToolChoiceNamed(name)
			}
		case "allowed_tools":
			// Only the mode survives; the allowed-tools list itself is not preserved.
			if at, ok := v["allowed_tools"].(map[string]any); ok {
				mode, _ := at["mode"].(string)
				if mode == "required" {
					return canonical.NewToolChoiceAny()
				}
				return canonical.NewToolChoiceAuto()
			}
			return canonical.NewToolChoiceAuto()
		}
	}
	return nil
}

// decodeParameters decodes sampling/limit parameters from the request.
// max_completion_tokens takes precedence over the deprecated max_tokens.
func decodeParameters(req *ChatCompletionRequest) canonical.RequestParameters {
	params := canonical.RequestParameters{
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		FrequencyPenalty: req.FrequencyPenalty,
		PresencePenalty:  req.PresencePenalty,
	}
	if req.MaxCompletionTokens != nil {
		params.MaxTokens = req.MaxCompletionTokens
	} else if req.MaxTokens != nil {
		params.MaxTokens = req.MaxTokens
	}
	if req.Stop != nil {
		params.StopSequences = normalizeStop(req.Stop)
	}
	return params
}

// normalizeStop normalizes the stop parameter (string, []any from JSON, or
// []string) into a string slice, dropping empty entries; nil when empty.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		if v == "" {
			return nil
		}
		return []string{v}
	case []any:
		var result []string
		for _, s := range v {
			if str, ok := s.(string); ok && str != "" {
				result = append(result, str)
			}
		}
		if len(result) == 0 {
			return nil
		}
		return result
	case []string:
		return v
	}
	return nil
}

// decodeOutputFormat decodes response_format into a canonical OutputFormat.
// "text" (the default) is represented as nil.
func decodeOutputFormat(format *ResponseFormat) *canonical.OutputFormat {
	if format == nil {
		return nil
	}
	switch format.Type {
	case "json_object":
		return &canonical.OutputFormat{Type: "json_object"}
	case "json_schema":
		if format.JSONSchema != nil {
			return &canonical.OutputFormat{
				Type:   "json_schema",
				Name:   format.JSONSchema.Name,
				Schema: format.JSONSchema.Schema,
				Strict: format.JSONSchema.Strict,
			}
		}
		return &canonical.OutputFormat{Type: "json_schema"}
	case "text":
		return nil
	}
	return nil
}

// decodeThinking decodes reasoning_effort into a canonical ThinkingConfig.
// "none" disables thinking; "minimal" is folded into "low".
func decodeThinking(reasoningEffort string) *canonical.ThinkingConfig {
	if reasoningEffort == "" {
		return nil
	}
	if reasoningEffort == "none" {
		return &canonical.ThinkingConfig{Type: "disabled"}
	}
	effort := reasoningEffort
	if effort == "minimal" {
		effort = "low"
	}
	return &canonical.ThinkingConfig{Type: "enabled", Effort: effort}
}

// decodeDeprecatedFields upgrades the deprecated functions/function_call
// fields in place to their tools/tool_choice equivalents, without
// overwriting values the caller already set.
func decodeDeprecatedFields(req *ChatCompletionRequest) {
	if len(req.Tools) == 0 && len(req.Functions) > 0 {
		req.Tools = make([]Tool, len(req.Functions))
		for i, f :=
range req.Functions {
			req.Tools[i] = Tool{
				Type: "function",
				Function: &FunctionDef{
					Name:        f.Name,
					Description: f.Description,
					Parameters:  f.Parameters,
				},
			}
		}
	}
	if req.ToolChoice == nil && req.FunctionCall != nil {
		switch v := req.FunctionCall.(type) {
		case string:
			switch v {
			case "none":
				req.ToolChoice = "none"
			case "auto":
				req.ToolChoice = "auto"
			}
		case map[string]any:
			if name, ok := v["name"].(string); ok {
				req.ToolChoice = map[string]any{
					"type":     "function",
					"function": map[string]any{"name": name},
				}
			}
		}
	}
}

// decodeResponse decodes an OpenAI chat completion response into a
// canonical response. Only the first choice is used.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp ChatCompletionResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 响应失败").WithCause(err)
	}

	// No choices: return an empty text response rather than fail.
	if len(resp.Choices) == 0 {
		return &canonical.CanonicalResponse{
			ID:      resp.ID,
			Model:   resp.Model,
			Content: []canonical.ContentBlock{canonical.NewTextBlock("")},
			Usage:   canonical.CanonicalUsage{},
		}, nil
	}

	choice := resp.Choices[0]
	var blocks []canonical.ContentBlock

	if choice.Message != nil {
		if choice.Message.Content != nil {
			text := extractText(choice.Message.Content)
			if text != "" {
				blocks = append(blocks, canonical.NewTextBlock(text))
			}
		}
		if choice.Message.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(choice.Message.Refusal))
		}
		if choice.Message.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(choice.Message.ReasoningContent))
		}
		for _, tc := range choice.Message.ToolCalls {
			var input json.RawMessage
			name := ""
			if tc.Type == "custom" && tc.Custom != nil {
				// Custom tool input is a raw string; quote it into a JSON string.
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
				name = tc.Custom.Name
			} else if tc.Function != nil {
				input = json.RawMessage(tc.Function.Arguments)
				if !json.Valid(input) {
					input = json.RawMessage("{}")
				}
				name = tc.Function.Name
			} else {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
	}

	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	var stopReason *canonical.StopReason
	if choice.FinishReason != nil {
		sr := mapFinishReason(*choice.FinishReason)
		stopReason = &sr
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: stopReason,
		Usage:      decodeUsage(resp.Usage),
	}, nil
}

// mapFinishReason maps an OpenAI finish_reason to a canonical StopReason.
// Unknown values default to end_turn.
func mapFinishReason(reason string) canonical.StopReason {
	switch reason {
	case "stop":
		return canonical.StopReasonEndTurn
	case "length":
		return canonical.StopReasonMaxTokens
	case "tool_calls":
		return canonical.StopReasonToolUse
	case "function_call":
		return canonical.StopReasonToolUse
	case "content_filter":
		return canonical.StopReasonContentFilter
	default:
		return canonical.StopReasonEndTurn
	}
}

// decodeUsage decodes OpenAI usage into canonical usage. Cached and
// reasoning token details are carried over only when positive.
func decodeUsage(usage *Usage) canonical.CanonicalUsage {
	if usage == nil {
		return canonical.CanonicalUsage{}
	}
	result := canonical.CanonicalUsage{
		InputTokens:  usage.PromptTokens,
		OutputTokens: usage.CompletionTokens,
	}
	if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
		val := usage.PromptTokensDetails.CachedTokens
		result.CacheReadTokens = &val
	}
	if usage.CompletionTokensDetails != nil && usage.CompletionTokensDetails.ReasoningTokens > 0 {
		val := usage.CompletionTokensDetails.ReasoningTokens
		result.ReasoningTokens = &val
	}
	return result
}

// decodeModelsResponse decodes an OpenAI model list response.
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
	var resp ModelsResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	models := make([]canonical.CanonicalModel, len(resp.Data))
	for i, m := range resp.Data {
		models[i] =
canonical.CanonicalModel{
			ID:      m.ID,
			Name:    m.ID,
			Created: m.Created,
			OwnedBy: m.OwnedBy,
		}
	}
	return &canonical.CanonicalModelList{Models: models}, nil
}

// decodeModelInfoResponse decodes a single model info response.
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
	var resp ModelInfoResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	return &canonical.CanonicalModelInfo{
		ID:      resp.ID,
		Name:    resp.ID,
		Created: resp.Created,
		OwnedBy: resp.OwnedBy,
	}, nil
}

// decodeEmbeddingRequest decodes an embedding request.
func decodeEmbeddingRequest(body []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	var req EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalEmbeddingRequest{
		Model:          req.Model,
		Input:          req.Input,
		EncodingFormat: req.EncodingFormat,
		Dimensions:     req.Dimensions,
	}, nil
}

// decodeEmbeddingResponse decodes an embedding response.
func decodeEmbeddingResponse(body []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	var resp EmbeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	data := make([]canonical.EmbeddingData, len(resp.Data))
	for i, d := range resp.Data {
		data[i] = canonical.EmbeddingData{Index: d.Index, Embedding: d.Embedding}
	}
	return &canonical.CanonicalEmbeddingResponse{
		Data:  data,
		Model: resp.Model,
		Usage: canonical.EmbeddingUsage{
			PromptTokens: resp.Usage.PromptTokens,
			TotalTokens:  resp.Usage.TotalTokens,
		},
	}, nil
}

// decodeRerankRequest decodes a rerank request.
func decodeRerankRequest(body []byte) (*canonical.CanonicalRerankRequest, error) {
	var req RerankRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalRerankRequest{
		Model:           req.Model,
		Query:           req.Query,
		Documents:       req.Documents,
		TopN:            req.TopN,
		ReturnDocuments: req.ReturnDocuments,
	}, nil
}

// decodeRerankResponse decodes a rerank response.
func decodeRerankResponse(body []byte) (*canonical.CanonicalRerankResponse, error) {
	var resp RerankResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	results := make([]canonical.RerankResult, len(resp.Results))
	for i, r := range resp.Results {
		results[i] = canonical.RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
			Document:       r.Document,
		}
	}
	return &canonical.CanonicalRerankResponse{Results: results, Model: resp.Model}, nil
}

// generateID generates a unique synthetic tool-call ID for deprecated
// function_call messages that carry no ID of their own.
func generateID() string {
	return fmt.Sprintf("call_%d", generateCounter())
}

// idCounter backs generateCounter; accessed atomically.
var idCounter int64

// generateCounter returns a process-wide monotonically increasing counter.
func generateCounter() int64 {
	return atomic.AddInt64(&idCounter, 1)
}
diff --git a/backend/internal/conversion/openai/decoder_test.go b/backend/internal/conversion/openai/decoder_test.go
new file mode 100644
index 0000000..3babfc0
--- /dev/null
+++ b/backend/internal/conversion/openai/decoder_test.go
@@ -0,0 +1,411 @@
package openai

import (
	"fmt"
	"testing"

	"nex/backend/internal/conversion/canonical"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestDecodeRequest_BasicChat verifies basic model/message/parameter decoding.
func TestDecodeRequest_BasicChat(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "你好"}
		],
		"temperature": 0.7
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.7, *req.Parameters.Temperature)
}

// TestDecodeRequest_SystemAndDeveloper verifies that system and developer
// messages are merged into the canonical system prompt.
func TestDecodeRequest_SystemAndDeveloper(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "system", "content": "你是助手"},
			{"role": "developer", "content": "额外指令"},
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手\n\n额外指令", req.System)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
}

// TestDecodeRequest_ToolCalls verifies assistant tool_calls become tool_use blocks.
func TestDecodeRequest_ToolCalls(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{
					"id": "call_123",
					"type": "function",
					"function": {"name": "get_weather", "arguments": "{\"city\":\"北京\"}"}
				}]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Messages, 2)
	assistantMsg := req.Messages[1]
	assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
	found := false
	for _, b := range assistantMsg.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_123", b.ID)
			assert.Equal(t, "get_weather", b.Name)
		}
	}
	assert.True(t, found)
}

// TestDecodeRequest_ToolMessage verifies tool role messages carry the tool_call_id.
func TestDecodeRequest_ToolMessage(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]
			},
			{
				"role": "tool",
				"tool_call_id": "call_1",
				"content": "晴天 25°C"
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	toolMsg := req.Messages[2]
	assert.Equal(t, canonical.RoleTool, toolMsg.Role)
	assert.Equal(t, "call_1", toolMsg.Content[0].ToolUseID)
}

// TestDecodeRequest_MissingModel verifies validation rejects a missing model.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_MissingMessages verifies validation rejects missing messages.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model":"gpt-4"}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_DeprecatedFunctions verifies functions are upgraded to tools.
func TestDecodeRequest_DeprecatedFunctions(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "test"}],
		"functions": [{
			"name": "get_weather",
			"description": "获取天气",
			"parameters": {"type":"object","properties":{"city":{"type":"string"}}}
		}]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Tools, 1)
	assert.Equal(t, "get_weather", req.Tools[0].Name)
}

// TestDecodeResponse_Basic verifies text content, stop reason, and usage decoding.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-123",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": "你好"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", resp.ID)
	assert.Equal(t, "gpt-4", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.NotNil(t, resp.StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
	assert.Equal(t, 5, resp.Usage.OutputTokens)
}

// TestDecodeResponse_ToolCalls verifies response tool_calls become tool_use blocks.
func TestDecodeResponse_ToolCalls(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-456",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"tool_calls": [{
					"id": "call_abc",
					"type": "function",
					"function": {"name": "search", "arguments": "{\"q\":\"test\"}"}
				}]
			},
			"finish_reason": "tool_calls"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_abc", b.ID)
			assert.Equal(t, "search", b.Name)
		}
	}
	assert.True(t, found)
	assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
}

// TestDecodeResponse_Thinking verifies reasoning_content becomes a thinking block.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-789",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"content": "回答",
				"reasoning_content": "思考过程"
			},
			"finish_reason": "stop"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "回答", resp.Content[0].Text)
	assert.Equal(t, "thinking", resp.Content[1].Type)
	assert.Equal(t, "思考过程", resp.Content[1].Thinking)
}

// TestDecodeModelsResponse verifies model list decoding.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"object": "list",
		"data": [
			{"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "openai"},
			{"id": "gpt-3.5-turbo", "object": "model", "created": 1700000001, "owned_by": "openai"}
		]
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "gpt-4", list.Models[0].ID)
	assert.Equal(t, "gpt-3.5-turbo", list.Models[1].ID)
	assert.Equal(t, int64(1700000000), list.Models[0].Created)
}

// TestDecodeRequest_InvalidJSON verifies malformed JSON is rejected.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}

// TestDecodeRequest_Parameters verifies all sampling/limit parameters decode.
func TestDecodeRequest_Parameters(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"temperature": 0.5,
		"max_completion_tokens": 2048,
		"top_p": 0.9,
		"frequency_penalty": 0.1,
		"presence_penalty": 0.2,
		"stop": ["STOP"]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.5, *req.Parameters.Temperature)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 2048, *req.Parameters.MaxTokens)
	assert.NotNil(t, req.Parameters.TopP)
	assert.Equal(t, 0.9, *req.Parameters.TopP)
	assert.NotNil(t, req.Parameters.FrequencyPenalty)
	assert.Equal(t, 0.1, *req.Parameters.FrequencyPenalty)
	assert.NotNil(t, req.Parameters.PresencePenalty)
	assert.Equal(t, 0.2, *req.Parameters.PresencePenalty)
	assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}

// TestDecodeRequest_ToolChoice verifies each tool_choice variant maps correctly.
func TestDecodeRequest_ToolChoice(t *testing.T) {
	tests := []struct {
		name     string
		jsonBody string
		want     *canonical.ToolChoice
	}{
		{
			name:     "auto",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
			want:     canonical.NewToolChoiceAuto(),
		},
		{
			name:     "none",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
			want:     canonical.NewToolChoiceNone(),
		},
		{
			name:     "required",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"required"}`,
			want:     canonical.NewToolChoiceAny(),
		},
		{
			name:     "named",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"function","function":{"name":"x"}}}`,
			want:     canonical.NewToolChoiceNamed("x"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := decodeRequest([]byte(tt.jsonBody))
			require.NoError(t, err)
			require.NotNil(t, req.ToolChoice)
			assert.Equal(t, tt.want.Type, req.ToolChoice.Type)
			assert.Equal(t, tt.want.Name, req.ToolChoice.Name)
		})
	}
}

// TestDecodeRequest_OutputFormat_JSONSchema verifies json_schema response_format decoding.
func TestDecodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {
			"type": "json_schema",
			"json_schema": {
				"name": "my_schema",
				"schema": {"type":"object","properties":{"name":{"type":"string"}}},
				"strict": true
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.Equal(t, "my_schema", req.OutputFormat.Name)
	assert.NotNil(t, req.OutputFormat.Schema)
	require.NotNil(t, req.OutputFormat.Strict)
	assert.True(t, *req.OutputFormat.Strict)
}

// TestDecodeRequest_OutputFormat_JSON verifies json_object response_format decoding.
func TestDecodeRequest_OutputFormat_JSON(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {"type": "json_object"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_object", req.OutputFormat.Type)
}

// TestDecodeResponse_StopReasons verifies the finish_reason mapping table.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name         string
		finishReason string
		want         canonical.StopReason
	}{
		{"stop→end_turn", "stop", canonical.StopReasonEndTurn},
		{"length→max_tokens", "length", canonical.StopReasonMaxTokens},
		{"tool_calls→tool_use", "tool_calls", canonical.StopReasonToolUse},
		{"content_filter→content_filter", "content_filter", canonical.StopReasonContentFilter},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "resp-1",
				"model": "gpt-4",
				"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "%s"}],
				"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
			}`, tt.finishReason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}

// TestDecodeResponse_Usage verifies cached-token details are carried over.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
		"usage": {
			"prompt_tokens": 100,
			"completion_tokens": 50,
			"total_tokens": 150,
			"prompt_tokens_details": {"cached_tokens": 80}
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 80, *resp.Usage.CacheReadTokens)
}

// TestDecodeResponse_Refusal verifies a refusal surfaces as a text block.
func TestDecodeResponse_Refusal(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": null, "refusal": "我拒绝回答"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Text == "我拒绝回答" {
			found = true
		}
	}
	assert.True(t, found)
}
diff --git a/backend/internal/conversion/openai/encoder.go b/backend/internal/conversion/openai/encoder.go
new file mode 100644
index 0000000..ff6de7b
--- /dev/null
+++ b/backend/internal/conversion/openai/encoder.go
@@ -0,0 +1,532 @@
package openai

import (
	"encoding/json"
	"time"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
)

// encodeRequest encodes a canonical request into an OpenAI chat completion
// request, targeting the provider's configured model name.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// System prompt + conversation messages.
	messages := encodeSystemAndMessages(req)
	result["messages"] = messages

	// Sampling/limit parameters.
	encodeParametersInto(req, result)

	// Tool definitions.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tools[i] = map[string]any{
				"type": "function",
				"function": map[string]any{
					"name":        t.Name,
					"description": t.Description,
					"parameters":  t.InputSchema,
				},
			}
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common optional fields.
	if req.UserID != "" {
		result["user"] = req.UserID
	}
	if req.OutputFormat != nil {
		result["response_format"] = encodeOutputFormat(req.OutputFormat)
	}
	if req.ParallelToolUse != nil {
		result["parallel_tool_calls"] = *req.ParallelToolUse
	}
	if req.Thinking != nil {
		switch req.Thinking.Type {
		case "disabled":
			result["reasoning_effort"] = "none"
		default:
			// Default to "medium" when thinking is enabled without an effort level.
			if req.Thinking.Effort != "" {
				result["reasoning_effort"] = req.Thinking.Effort
			} else {
				result["reasoning_effort"] = "medium"
			}
		}
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil,
conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 请求失败").WithCause(err)
	}
	return body, nil
}

// encodeSystemAndMessages builds the OpenAI message list: the canonical
// system prompt (string or block list) becomes a leading system message,
// followed by the encoded conversation, with consecutive same-role
// messages merged.
func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any {
	var messages []map[string]any

	// System prompt.
	switch v := req.System.(type) {
	case string:
		if v != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": v,
			})
		}
	case []canonical.SystemBlock:
		// Multiple system blocks are flattened into one blank-line-joined message.
		var parts []string
		for _, b := range v {
			parts = append(parts, b.Text)
		}
		text := joinStrings(parts, "\n\n")
		if text != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": text,
			})
		}
	}

	// Conversation messages.
	for _, msg := range req.Messages {
		encoded := encodeMessage(msg)
		messages = append(messages, encoded...)
	}

	// Merge consecutive same-role messages.
	return mergeConsecutiveRoles(messages)
}

// encodeMessage encodes one canonical message into zero or more OpenAI
// messages. Thinking blocks are intentionally not echoed back.
func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
	switch msg.Role {
	case canonical.RoleUser:
		return []map[string]any{{
			"role":    "user",
			"content": encodeUserContent(msg.Content),
		}}
	case canonical.RoleAssistant:
		m := map[string]any{"role": "assistant"}
		var textParts []string
		var toolUses []canonical.ContentBlock

		for _, b := range msg.Content {
			switch b.Type {
			case "text":
				textParts = append(textParts, b.Text)
			case "tool_use":
				toolUses = append(toolUses, b)
			}
		}

		if len(toolUses) > 0 {
			// OpenAI uses null content when only tool calls are present.
			if len(textParts) > 0 {
				m["content"] = joinStrings(textParts, "")
			} else {
				m["content"] = nil
			}
			tcs := make([]map[string]any, len(toolUses))
			for i, tu := range toolUses {
				tcs[i] = map[string]any{
					"id":   tu.ID,
					"type": "function",
					"function": map[string]any{
						"name":      tu.Name,
						"arguments": string(tu.Input),
					},
				}
			}
			m["tool_calls"] = tcs
		} else if len(textParts) > 0 {
			m["content"] = joinStrings(textParts, "")
		} else {
			m["content"] = ""
		}
		return []map[string]any{m}

	case canonical.RoleTool:
		// Only the first tool_result block is encoded.
		for _, b := range msg.Content {
			if b.Type == "tool_result" {
				var contentStr string
				if b.Content != nil {
					// Prefer the unwrapped JSON string; fall back to raw bytes.
					var s string
					if json.Unmarshal(b.Content, &s) == nil {
						contentStr = s
					} else {
						contentStr = string(b.Content)
					}
				}
				return []map[string]any{{
					"role":         "tool",
					"tool_call_id": b.ToolUseID,
					"content":      contentStr,
				}}
			}
		}
	}
	return nil
}

// encodeUserContent encodes user content blocks: a single text block becomes
// a plain string; otherwise a part array (media payloads are not carried).
func encodeUserContent(blocks []canonical.ContentBlock) any {
	if len(blocks) == 1 && blocks[0].Type == "text" {
		return blocks[0].Text
	}
	parts := make([]map[string]any, 0, len(blocks))
	for _, b := range blocks {
		switch b.Type {
		case "text":
			parts = append(parts, map[string]any{"type": "text", "text": b.Text})
		case "image":
			parts = append(parts, map[string]any{"type": "image_url"})
		case "audio":
			parts = append(parts, map[string]any{"type": "input_audio"})
		case "file":
			parts = append(parts, map[string]any{"type": "file"})
		}
	}
	if len(parts) == 0 {
		return ""
	}
	return parts
}

// encodeToolChoice encodes a canonical ToolChoice into the OpenAI
// tool_choice value; unknown types default to "auto".
func encodeToolChoice(choice *canonical.ToolChoice) any {
	switch choice.Type {
	case "auto":
		return "auto"
	case "none":
		return "none"
	case "any":
		// Canonical "any" maps to OpenAI "required".
		return "required"
	case "tool":
		return map[string]any{
			"type": "function",
			"function": map[string]any{
				"name": choice.Name,
			},
		}
	}
	return "auto"
}

// encodeParametersInto writes canonical parameters into the request map.
// MaxTokens is emitted as max_completion_tokens (the non-deprecated field).
func encodeParametersInto(req *canonical.CanonicalRequest, result map[string]any) {
	if req.Parameters.MaxTokens != nil {
		result["max_completion_tokens"] = *req.Parameters.MaxTokens
	}
	if req.Parameters.Temperature != nil {
		result["temperature"] = *req.Parameters.Temperature
	}
	if req.Parameters.TopP != nil {
		result["top_p"] = *req.Parameters.TopP
	}
	if req.Parameters.FrequencyPenalty != nil {
		result["frequency_penalty"] = *req.Parameters.FrequencyPenalty
	}
	if req.Parameters.PresencePenalty != nil {
		result["presence_penalty"] = *req.Parameters.PresencePenalty
	}
	if len(req.Parameters.StopSequences) > 0 {
		result["stop"] = req.Parameters.StopSequences
	}
}

// encodeOutputFormat encodes a canonical OutputFormat into response_format;
// returns nil for unknown types.
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
	switch format.Type {
	case "json_object":
		return map[string]any{"type": "json_object"}
	case "json_schema":
		m := map[string]any{"type": "json_schema"}
		schema := map[string]any{
			"name": format.Name,
		}
		if format.Schema != nil {
			schema["schema"] = format.Schema
		}
		if format.Strict != nil {
			schema["strict"] = *format.Strict
		}
		m["json_schema"] = schema
		return m
	}
	return nil
}

// encodeResponse encodes a canonical response as an OpenAI chat completion
// response with a single choice.
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	var textParts []string
	var thinkingParts []string
	var toolUses []canonical.ContentBlock

	for _, b := range resp.Content {
		switch b.Type {
		case "text":
			textParts = append(textParts, b.Text)
		case "thinking":
			thinkingParts = append(thinkingParts, b.Thinking)
		case "tool_use":
			toolUses = append(toolUses, b)
		}
	}

	message := map[string]any{"role": "assistant"}
	if len(toolUses) > 0 {
		// Null content when only tool calls are present, per OpenAI convention.
		if len(textParts) > 0 {
			message["content"] = joinStrings(textParts, "")
		} else {
			message["content"] = nil
		}
		tcs := make([]map[string]any, len(toolUses))
		for i, tu := range toolUses {
			tcs[i] = map[string]any{
				"id":   tu.ID,
				"type": "function",
				"function": map[string]any{
					"name":      tu.Name,
					"arguments": string(tu.Input),
				},
			}
		}
		message["tool_calls"] = tcs
	} else if len(textParts) > 0 {
		message["content"] = joinStrings(textParts, "")
	} else {
		message["content"] = ""
	}

	// Thinking blocks are exposed via the non-standard reasoning_content field.
	if len(thinkingParts) > 0 {
		message["reasoning_content"] = joinStrings(thinkingParts, "")
	}

	var finishReason *string
	if resp.StopReason != nil {
		fr := mapCanonicalToFinishReason(*resp.StopReason)
		finishReason = &fr
	}

	result := map[string]any{
		"id": resp.ID,
		"object":
"chat.completion", + "created": time.Now().Unix(), + "model": resp.Model, + "choices": []map[string]any{{ + "index": 0, + "message": message, + "finish_reason": finishReason, + }}, + "usage": encodeUsage(resp.Usage), + } + + body, err := json.Marshal(result) + if err != nil { + return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 响应失败").WithCause(err) + } + return body, nil +} + +// mapCanonicalToFinishReason 映射 Canonical 停止原因到 OpenAI finish_reason +func mapCanonicalToFinishReason(reason canonical.StopReason) string { + switch reason { + case canonical.StopReasonEndTurn: + return "stop" + case canonical.StopReasonMaxTokens: + return "length" + case canonical.StopReasonToolUse: + return "tool_calls" + case canonical.StopReasonContentFilter: + return "content_filter" + case canonical.StopReasonStopSequence: + return "stop" + case canonical.StopReasonRefusal: + return "stop" + default: + return "stop" + } +} + +// encodeUsage 编码用量 +func encodeUsage(usage canonical.CanonicalUsage) map[string]any { + result := map[string]any{ + "prompt_tokens": usage.InputTokens, + "completion_tokens": usage.OutputTokens, + "total_tokens": usage.InputTokens + usage.OutputTokens, + } + if usage.CacheReadTokens != nil && *usage.CacheReadTokens > 0 { + result["prompt_tokens_details"] = map[string]any{ + "cached_tokens": *usage.CacheReadTokens, + } + } + if usage.ReasoningTokens != nil && *usage.ReasoningTokens > 0 { + result["completion_tokens_details"] = map[string]any{ + "reasoning_tokens": *usage.ReasoningTokens, + } + } + return result +} + +// encodeModelsResponse 编码模型列表响应 +func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) { + data := make([]map[string]any, len(list.Models)) + for i, m := range list.Models { + created := int64(0) + if m.Created != 0 { + created = m.Created + } + ownedBy := "unknown" + if m.OwnedBy != "" { + ownedBy = m.OwnedBy + } + data[i] = map[string]any{ + "id": m.ID, + "object": "model", + 
"created": created, + "owned_by": ownedBy, + } + } + return json.Marshal(map[string]any{ + "object": "list", + "data": data, + }) +} + +// encodeModelInfoResponse 编码模型详情响应 +func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) { + created := int64(0) + if info.Created != 0 { + created = info.Created + } + ownedBy := "unknown" + if info.OwnedBy != "" { + ownedBy = info.OwnedBy + } + return json.Marshal(map[string]any{ + "id": info.ID, + "object": "model", + "created": created, + "owned_by": ownedBy, + }) +} + +// encodeEmbeddingRequest 编码嵌入请求 +func encodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) { + result := map[string]any{ + "model": provider.ModelName, + "input": req.Input, + } + if req.EncodingFormat != "" { + result["encoding_format"] = req.EncodingFormat + } + if req.Dimensions != nil { + result["dimensions"] = *req.Dimensions + } + return json.Marshal(result) +} + +// encodeEmbeddingResponse 编码嵌入响应 +func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) { + data := make([]map[string]any, len(resp.Data)) + for i, d := range resp.Data { + data[i] = map[string]any{ + "index": d.Index, + "embedding": d.Embedding, + } + } + return json.Marshal(map[string]any{ + "object": "list", + "data": data, + "model": resp.Model, + "usage": resp.Usage, + }) +} + +// encodeRerankRequest 编码重排序请求 +func encodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) { + result := map[string]any{ + "model": provider.ModelName, + "query": req.Query, + "documents": req.Documents, + } + if req.TopN != nil { + result["top_n"] = *req.TopN + } + if req.ReturnDocuments != nil { + result["return_documents"] = *req.ReturnDocuments + } + return json.Marshal(result) +} + +// encodeRerankResponse 编码重排序响应 +func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { + results := 
make([]map[string]any, len(resp.Results)) + for i, r := range resp.Results { + m := map[string]any{ + "index": r.Index, + "relevance_score": r.RelevanceScore, + } + if r.Document != nil { + m["document"] = *r.Document + } + results[i] = m + } + return json.Marshal(map[string]any{ + "results": results, + "model": resp.Model, + }) +} + +// joinStrings 拼接字符串切片 +func joinStrings(parts []string, sep string) string { + result := "" + for i, p := range parts { + if i > 0 { + result += sep + } + result += p + } + return result +} + +// mergeConsecutiveRoles 合并连续同角色消息(拼接内容) +func mergeConsecutiveRoles(messages []map[string]any) []map[string]any { + if len(messages) <= 1 { + return messages + } + var result []map[string]any + for _, msg := range messages { + if len(result) > 0 { + lastRole := result[len(result)-1]["role"] + currRole := msg["role"] + if lastRole == currRole { + lastContent := result[len(result)-1]["content"] + currContent := msg["content"] + switch lv := lastContent.(type) { + case string: + if cv, ok := currContent.(string); ok { + result[len(result)-1]["content"] = lv + cv + } + case []any: + if cv, ok := currContent.([]any); ok { + result[len(result)-1]["content"] = append(lv, cv...) 
+ } + } + continue + } + } + result = append(result, msg) + } + return result +} diff --git a/backend/internal/conversion/openai/encoder_test.go b/backend/internal/conversion/openai/encoder_test.go new file mode 100644 index 0000000..2e1f5b3 --- /dev/null +++ b/backend/internal/conversion/openai/encoder_test.go @@ -0,0 +1,355 @@ +package openai + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeRequest_Basic(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + Stream: true, + } + provider := conversion.NewTargetProvider("", "key", "my-model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "my-model", result["model"]) + assert.Equal(t, true, result["stream"]) + + msgs, ok := result["messages"].([]any) + require.True(t, ok) + assert.Len(t, msgs, 1) +} + +func TestEncodeRequest_SystemInjection(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + System: "你是助手", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + msgs := result["messages"].([]any) + assert.Len(t, msgs, 2) + firstMsg := msgs[0].(map[string]any) + assert.Equal(t, "system", firstMsg["role"]) + assert.Equal(t, "你是助手", firstMsg["content"]) +} + +func TestEncodeRequest_ToolCalls(t *testing.T) { + input := json.RawMessage(`{"city":"北京"}`) 
+ req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{ + { + Role: canonical.RoleAssistant, + Content: []canonical.ContentBlock{ + canonical.NewToolUseBlock("call_1", "get_weather", input), + }, + }, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + msgs := result["messages"].([]any) + assistantMsg := msgs[0].(map[string]any) + toolCalls, ok := assistantMsg["tool_calls"].([]any) + require.True(t, ok) + assert.Len(t, toolCalls, 1) + tc := toolCalls[0].(map[string]any) + assert.Equal(t, "call_1", tc["id"]) +} + +func TestEncodeRequest_Thinking(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "high", result["reasoning_effort"]) +} + +func TestEncodeResponse_Basic(t *testing.T) { + sr := canonical.StopReasonEndTurn + resp := &canonical.CanonicalResponse{ + ID: "resp-1", + Model: "gpt-4", + Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")}, + StopReason: &sr, + Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5}, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "resp-1", result["id"]) + assert.Equal(t, "chat.completion", result["object"]) + + choices := result["choices"].([]any) + choice := choices[0].(map[string]any) + msg := 
choice["message"].(map[string]any) + assert.Equal(t, "你好", msg["content"]) + assert.Equal(t, "stop", choice["finish_reason"]) +} + +func TestEncodeResponse_ToolUse(t *testing.T) { + sr := canonical.StopReasonToolUse + input := json.RawMessage(`{"q":"test"}`) + resp := &canonical.CanonicalResponse{ + ID: "resp-2", + Model: "gpt-4", + Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)}, + StopReason: &sr, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + choices := result["choices"].([]any) + msg := choices[0].(map[string]any)["message"].(map[string]any) + tcs, ok := msg["tool_calls"].([]any) + require.True(t, ok) + assert.Len(t, tcs, 1) +} + +func TestEncodeModelsResponse(t *testing.T) { + list := &canonical.CanonicalModelList{ + Models: []canonical.CanonicalModel{ + {ID: "gpt-4", Created: 1700000000, OwnedBy: "openai"}, + {ID: "gpt-3.5-turbo", Created: 1700000001, OwnedBy: "openai"}, + }, + } + + body, err := encodeModelsResponse(list) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "list", result["object"]) + data := result["data"].([]any) + assert.Len(t, data, 2) +} + +func TestMergeConsecutiveRoles(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "A"}, + {"role": "user", "content": "B"}, + {"role": "assistant", "content": "C"}, + {"role": "assistant", "content": "D"}, + } + + result := mergeConsecutiveRoles(messages) + assert.Len(t, result, 2) + assert.Equal(t, "AB", result[0]["content"]) + assert.Equal(t, "CD", result[1]["content"]) +} + +func TestMergeConsecutiveRoles_NotOverwriting(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "你好"}, + {"role": "user", "content": "世界"}, + } + + result := mergeConsecutiveRoles(messages) + assert.Len(t, result, 1) + assert.Equal(t, "你好世界", 
result[0]["content"]) +} + +func TestEncodeRequest_ToolChoice_Auto(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + ToolChoice: canonical.NewToolChoiceAuto(), + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "auto", result["tool_choice"]) +} + +func TestEncodeRequest_ToolChoice_None(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + ToolChoice: canonical.NewToolChoiceNone(), + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "none", result["tool_choice"]) +} + +func TestEncodeRequest_ToolChoice_Required(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + ToolChoice: canonical.NewToolChoiceAny(), + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, "required", result["tool_choice"]) +} + +func TestEncodeRequest_ToolChoice_Named(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + ToolChoice: 
canonical.NewToolChoiceNamed("my_func"), + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + tc, ok := result["tool_choice"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "function", tc["type"]) + fn, ok := tc["function"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "my_func", fn["name"]) +} + +func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`) + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + OutputFormat: &canonical.OutputFormat{ + Type: "json_schema", + Name: "my_schema", + Schema: schema, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + rf, ok := result["response_format"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "json_schema", rf["type"]) + js, ok := rf["json_schema"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "my_schema", js["name"]) + assert.NotNil(t, js["schema"]) +} + +func TestEncodeRequest_OutputFormat_Text(t *testing.T) { + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + _, hasResponseFormat := result["response_format"] + assert.False(t, hasResponseFormat) +} + 
+func TestEncodeResponse_Thinking(t *testing.T) { + sr := canonical.StopReasonEndTurn + resp := &canonical.CanonicalResponse{ + ID: "resp-thinking", + Model: "gpt-4", + Content: []canonical.ContentBlock{ + canonical.NewTextBlock("回答"), + canonical.NewThinkingBlock("思考过程"), + }, + StopReason: &sr, + Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5}, + } + + body, err := encodeResponse(resp) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + choices := result["choices"].([]any) + msg := choices[0].(map[string]any)["message"].(map[string]any) + assert.Equal(t, "回答", msg["content"]) + assert.Equal(t, "思考过程", msg["reasoning_content"]) +} + +func TestEncodeRequest_Parameters(t *testing.T) { + temp := 0.5 + maxTokens := 2048 + topP := 0.9 + req := &canonical.CanonicalRequest{ + Model: "gpt-4", + Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}}, + Parameters: canonical.RequestParameters{ + Temperature: &temp, + MaxTokens: &maxTokens, + TopP: &topP, + StopSequences: []string{"STOP", "END"}, + }, + } + provider := conversion.NewTargetProvider("", "key", "model") + + body, err := encodeRequest(req, provider) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(body, &result)) + assert.Equal(t, temp, result["temperature"]) + assert.Equal(t, float64(maxTokens), result["max_completion_tokens"]) + assert.Equal(t, topP, result["top_p"]) + stop, ok := result["stop"].([]any) + require.True(t, ok) + assert.Len(t, stop, 2) + assert.Equal(t, "STOP", stop[0]) + assert.Equal(t, "END", stop[1]) +} diff --git a/backend/internal/conversion/openai/stream_decoder.go b/backend/internal/conversion/openai/stream_decoder.go new file mode 100644 index 0000000..fbff4ca --- /dev/null +++ b/backend/internal/conversion/openai/stream_decoder.go @@ -0,0 +1,230 @@ +package openai + +import ( + "encoding/json" + 
"fmt" + "strings" + "unicode/utf8" + + "nex/backend/internal/conversion/canonical" +) + +// StreamDecoder OpenAI 流式解码器 +type StreamDecoder struct { + messageStarted bool + openBlocks map[int]string + textBlockIndex int + thinkingBlockIndex int + refusalBlockIndex int + toolCallIDMap map[int]string + toolCallNameMap map[int]string + nextToolCallIdx int + utf8Remainder []byte + accumulatedUsage *canonical.CanonicalUsage +} + +// NewStreamDecoder 创建 OpenAI 流式解码器 +func NewStreamDecoder() *StreamDecoder { + return &StreamDecoder{ + openBlocks: make(map[int]string), + toolCallIDMap: make(map[int]string), + toolCallNameMap: make(map[int]string), + textBlockIndex: -1, + thinkingBlockIndex: -1, + refusalBlockIndex: -1, + } +} + +// ProcessChunk 处理原始 SSE chunk +func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { + // 处理 UTF-8 残余 + data := rawChunk + if len(d.utf8Remainder) > 0 { + data = append(d.utf8Remainder, rawChunk...) + d.utf8Remainder = nil + } + + var events []canonical.CanonicalStreamEvent + + // 解析 SSE data 行 + lines := strings.Split(string(data), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") { + continue + } + payload := strings.TrimPrefix(line, "data: ") + + if payload == "[DONE]" { + events = append(events, d.flushOpenBlocks()...) + return events + } + + chunkEvents := d.processDataChunk([]byte(payload)) + events = append(events, chunkEvents...) + } + + return events +} + +// Flush 刷新解码器状态 +func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent { + return nil +} + +// processDataChunk 处理单个 data chunk +func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent { + // 检查 UTF-8 完整性 + if !utf8.Valid(data) { + validEnd := len(data) + for !utf8.Valid(data[:validEnd]) { + validEnd-- + } + d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...) 
+ data = data[:validEnd] + } + + var chunk StreamChunk + if err := json.Unmarshal(data, &chunk); err != nil { + return nil + } + + var events []canonical.CanonicalStreamEvent + + // 首个 chunk: MessageStart + if !d.messageStarted { + events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model)) + d.messageStarted = true + } + + for _, choice := range chunk.Choices { + if choice.Delta == nil { + continue + } + delta := choice.Delta + + // text content + if delta.Content != nil { + text := "" + switch v := delta.Content.(type) { + case string: + text = v + default: + text = fmt.Sprintf("%v", v) + } + if text != "" { + if _, ok := d.openBlocks[d.textBlockIndex]; !ok || d.textBlockIndex < 0 { + d.textBlockIndex = d.allocateBlockIndex() + d.openBlocks[d.textBlockIndex] = "text" + events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex, + canonical.StreamContentBlock{Type: "text", Text: ""})) + } + events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex, + canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text})) + } + } + + // reasoning_content (非标准) + if delta.ReasoningContent != "" { + if _, ok := d.openBlocks[d.thinkingBlockIndex]; !ok || d.thinkingBlockIndex < 0 { + d.thinkingBlockIndex = d.allocateBlockIndex() + d.openBlocks[d.thinkingBlockIndex] = "thinking" + events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex, + canonical.StreamContentBlock{Type: "thinking", Thinking: ""})) + } + events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex, + canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent})) + } + + // refusal + if delta.Refusal != "" { + if _, ok := d.openBlocks[d.refusalBlockIndex]; !ok || d.refusalBlockIndex < 0 { + d.refusalBlockIndex = d.allocateBlockIndex() + d.openBlocks[d.refusalBlockIndex] = "text" + events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex, + 
canonical.StreamContentBlock{Type: "text", Text: ""})) + } + events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex, + canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal})) + } + + // tool_calls + if len(delta.ToolCalls) > 0 { + for _, tc := range delta.ToolCalls { + tcIdx := 0 + if tc.Index != nil { + tcIdx = *tc.Index + } + + if tc.ID != "" { + // 新 tool call block + d.toolCallIDMap[tcIdx] = tc.ID + if tc.Function != nil { + d.toolCallNameMap[tcIdx] = tc.Function.Name + } + blockIdx := d.allocateBlockIndex() + d.openBlocks[blockIdx] = fmt.Sprintf("tool_use_%d", tcIdx) + name := d.toolCallNameMap[tcIdx] + events = append(events, canonical.NewContentBlockStartEvent(blockIdx, + canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: name})) + } + + // 查找该 tool call 的 block index + blockIdx := d.findToolUseBlockIndex(tcIdx) + if tc.Function != nil && tc.Function.Arguments != "" { + events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx, + canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments})) + } + } + } + + // finish_reason + if choice.FinishReason != nil && *choice.FinishReason != "" { + events = append(events, d.flushOpenBlocks()...) 
+ sr := mapFinishReason(*choice.FinishReason) + events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil)) + events = append(events, canonical.NewMessageStopEvent()) + } + } + + // usage chunk (choices 为空) + if len(chunk.Choices) == 0 && chunk.Usage != nil { + usage := decodeUsage(chunk.Usage) + d.accumulatedUsage = &usage + events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage)) + } + + return events +} + +// allocateBlockIndex 分配 block 索引 +func (d *StreamDecoder) allocateBlockIndex() int { + maxIdx := -1 + for k := range d.openBlocks { + if k > maxIdx { + maxIdx = k + } + } + return maxIdx + 1 +} + +// findToolUseBlockIndex 查找 tool use block 索引 +func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int { + key := fmt.Sprintf("tool_use_%d", tcIdx) + for blockIdx, typ := range d.openBlocks { + if typ == key { + return blockIdx + } + } + return d.allocateBlockIndex() +} + +// flushOpenBlocks 关闭所有 open blocks +func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent { + var events []canonical.CanonicalStreamEvent + for idx := range d.openBlocks { + events = append(events, canonical.NewContentBlockStopEvent(idx)) + } + d.openBlocks = make(map[int]string) + return events +} diff --git a/backend/internal/conversion/openai/stream_decoder_test.go b/backend/internal/conversion/openai/stream_decoder_test.go new file mode 100644 index 0000000..21a147d --- /dev/null +++ b/backend/internal/conversion/openai/stream_decoder_test.go @@ -0,0 +1,355 @@ +package openai + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeSSEData(payload string) []byte { + return []byte("data: " + payload + "\n\n") +} + +func TestStreamDecoder_BasicText(t *testing.T) { + d := NewStreamDecoder() + + chunk := map[string]any{ + "id": "chatcmpl-1", + "object": "chat.completion.chunk", + 
"model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{"content": "你好"}, + }, + }, + } + data, _ := json.Marshal(chunk) + raw := makeSSEData(string(data)) + + events := d.ProcessChunk(raw) + require.NotEmpty(t, events) + + foundStart := false + foundDelta := false + for _, e := range events { + if e.Type == canonical.EventMessageStart { + foundStart = true + assert.Equal(t, "chatcmpl-1", e.Message.ID) + } + if e.Type == canonical.EventContentBlockDelta && e.Delta != nil { + foundDelta = true + assert.Equal(t, "text_delta", e.Delta.Type) + assert.Equal(t, "你好", e.Delta.Text) + } + } + assert.True(t, foundStart) + assert.True(t, foundDelta) +} + +func TestStreamDecoder_ToolCalls(t *testing.T) { + d := NewStreamDecoder() + + idx := 0 + chunk := map[string]any{ + "id": "chatcmpl-1", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "tool_calls": []any{ + map[string]any{ + "index": &idx, + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"北京\"}", + }, + }, + }, + }, + }, + }, + } + data, _ := json.Marshal(chunk) + raw := makeSSEData(string(data)) + + events := d.ProcessChunk(raw) + require.NotEmpty(t, events) + + found := false + for _, e := range events { + if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" { + found = true + assert.Equal(t, "call_1", e.ContentBlock.ID) + assert.Equal(t, "get_weather", e.ContentBlock.Name) + } + } + assert.True(t, found) +} + +func TestStreamDecoder_Thinking(t *testing.T) { + d := NewStreamDecoder() + + chunk := map[string]any{ + "id": "chatcmpl-1", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "reasoning_content": "思考中", + }, + }, + }, + } + data, _ := json.Marshal(chunk) + raw := makeSSEData(string(data)) + + events := d.ProcessChunk(raw) + found := false + for _, 
e := range events { + if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "thinking_delta" { + found = true + assert.Equal(t, "思考中", e.Delta.Thinking) + } + } + assert.True(t, found) +} + +func TestStreamDecoder_FinishReason(t *testing.T) { + d := NewStreamDecoder() + + chunk := map[string]any{ + "id": "chatcmpl-1", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + data, _ := json.Marshal(chunk) + raw := makeSSEData(string(data)) + + events := d.ProcessChunk(raw) + foundStop := false + foundMsgStop := false + for _, e := range events { + if e.Type == canonical.EventMessageDelta && e.StopReason != nil { + foundStop = true + assert.Equal(t, canonical.StopReasonEndTurn, *e.StopReason) + } + if e.Type == canonical.EventMessageStop { + foundMsgStop = true + } + } + assert.True(t, foundStop) + assert.True(t, foundMsgStop) +} + +func TestStreamDecoder_DoneSignal(t *testing.T) { + d := NewStreamDecoder() + + // 先发送一个文本 chunk + chunk := map[string]any{ + "id": "chatcmpl-1", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{"content": "hi"}, + }, + }, + } + data, _ := json.Marshal(chunk) + raw := append(makeSSEData(string(data)), []byte("data: [DONE]\n\n")...) 
+ + events := d.ProcessChunk(raw) + // 应该包含 block stop 事件([DONE] 触发 flushOpenBlocks) + foundBlockStop := false + for _, e := range events { + if e.Type == canonical.EventContentBlockStop { + foundBlockStop = true + } + } + assert.True(t, foundBlockStop) +} + +func TestStreamDecoder_RefusalReuse(t *testing.T) { + d := NewStreamDecoder() + + // 连续两个 refusal delta chunk + for _, text := range []string{"拒绝", "原因"} { + chunk := map[string]any{ + "id": "chatcmpl-1", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{"refusal": text}, + }, + }, + } + data, _ := json.Marshal(chunk) + raw := makeSSEData(string(data)) + events := d.ProcessChunk(raw) + _ = events + } + + // 检查只创建了一个 text block(refusal 复用同一个 block) + assert.Contains(t, d.openBlocks, d.refusalBlockIndex) +} + +func makeChunkSSE(chunk map[string]any) []byte { + data, _ := json.Marshal(chunk) + return []byte("data: " + string(data) + "\n\n") +} + +func TestStreamDecoder_UsageChunk(t *testing.T) { + d := NewStreamDecoder() + + chunk := map[string]any{ + "id": "chatcmpl-usage", + "object": "chat.completion.chunk", + "model": "gpt-4", + "choices": []any{}, + "usage": map[string]any{ + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + } + raw := makeChunkSSE(chunk) + + events := d.ProcessChunk(raw) + require.NotEmpty(t, events) + + found := false + for _, e := range events { + if e.Type == canonical.EventMessageDelta { + found = true + require.NotNil(t, e.Usage) + assert.Equal(t, 100, e.Usage.InputTokens) + assert.Equal(t, 50, e.Usage.OutputTokens) + } + } + assert.True(t, found) +} + +func TestStreamDecoder_MultipleToolCalls(t *testing.T) { + d := NewStreamDecoder() + + idx0 := 0 + chunk1 := map[string]any{ + "id": "chatcmpl-mt", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "tool_calls": []any{ + map[string]any{ + "index": &idx0, + "id": "call_a", + "type": "function", + "function": 
map[string]any{ + "name": "func_a", + "arguments": "{}", + }, + }, + }, + }, + }, + }, + } + + idx1 := 1 + chunk2 := map[string]any{ + "id": "chatcmpl-mt", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "tool_calls": []any{ + map[string]any{ + "index": &idx1, + "id": "call_b", + "type": "function", + "function": map[string]any{ + "name": "func_b", + "arguments": "{}", + }, + }, + }, + }, + }, + }, + } + + events1 := d.ProcessChunk(makeChunkSSE(chunk1)) + require.NotEmpty(t, events1) + + events2 := d.ProcessChunk(makeChunkSSE(chunk2)) + require.NotEmpty(t, events2) + + blockIndices := map[int]bool{} + for _, e := range append(events1, events2...) { + if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" { + require.NotNil(t, e.Index) + blockIndices[*e.Index] = true + } + } + assert.Equal(t, 2, len(blockIndices), "两个 tool call 应分配不同的 block 索引") +} + +func TestStreamDecoder_Flush(t *testing.T) { + d := NewStreamDecoder() + result := d.Flush() + assert.Nil(t, result) +} + +func TestStreamDecoder_MultipleChunks_Text(t *testing.T) { + d := NewStreamDecoder() + + chunk1 := map[string]any{ + "id": "chatcmpl-multi", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{"content": "你好"}, + }, + }, + } + chunk2 := map[string]any{ + "id": "chatcmpl-multi", + "model": "gpt-4", + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{"content": "世界"}, + }, + }, + } + + raw := append(makeChunkSSE(chunk1), makeChunkSSE(chunk2)...) 
+ events := d.ProcessChunk(raw) + + deltas := []string{} + for _, e := range events { + if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "text_delta" { + deltas = append(deltas, e.Delta.Text) + } + } + assert.Equal(t, []string{"你好", "世界"}, deltas) +} diff --git a/backend/internal/conversion/openai/stream_encoder.go b/backend/internal/conversion/openai/stream_encoder.go new file mode 100644 index 0000000..775fe18 --- /dev/null +++ b/backend/internal/conversion/openai/stream_encoder.go @@ -0,0 +1,217 @@ +package openai + +import ( + "encoding/json" + "fmt" + "time" + + "nex/backend/internal/conversion/canonical" +) + +// StreamEncoder OpenAI 流式编码器 +type StreamEncoder struct { + bufferedStart *canonical.CanonicalStreamEvent + toolCallIndexMap map[string]int + nextToolCallIndex int +} + +// NewStreamEncoder 创建 OpenAI 流式编码器 +func NewStreamEncoder() *StreamEncoder { + return &StreamEncoder{ + toolCallIndexMap: make(map[string]int), + } +} + +// EncodeEvent 编码 Canonical 事件为 SSE chunk +func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { + switch event.Type { + case canonical.EventMessageStart: + return e.encodeMessageStart(event) + case canonical.EventContentBlockStart: + return e.bufferBlockStart(event) + case canonical.EventContentBlockDelta: + return e.encodeContentBlockDelta(event) + case canonical.EventContentBlockStop: + return nil + case canonical.EventMessageDelta: + return e.encodeMessageDelta(event) + case canonical.EventMessageStop: + return [][]byte{[]byte("data: [DONE]\n\n")} + case canonical.EventPing, canonical.EventError: + return nil + } + return nil +} + +// Flush 刷新缓冲区 +func (e *StreamEncoder) Flush() [][]byte { + return nil +} + +// encodeMessageStart 编码消息开始事件 +func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte { + id := "" + model := "" + if event.Message != nil { + id = event.Message.ID + model = event.Message.Model + } + + chunk := 
map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"role": "assistant"}, + }}, + } + + return e.marshalChunk(chunk) +} + +// bufferBlockStart 缓冲 block start 事件 +func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte { + e.bufferedStart = &event + if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" { + idx := e.nextToolCallIndex + e.nextToolCallIndex++ + if event.ContentBlock.ID != "" { + e.toolCallIndexMap[event.ContentBlock.ID] = idx + } + } + return nil +} + +// encodeContentBlockDelta 编码内容块增量事件 +func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte { + if event.Delta == nil { + return nil + } + + switch canonical.DeltaType(event.Delta.Type) { + case canonical.DeltaTypeText: + return e.encodeTextDelta(event) + case canonical.DeltaTypeInputJSON: + return e.encodeInputJSONDelta(event) + case canonical.DeltaTypeThinking: + return e.encodeThinkingDelta(event) + } + return nil +} + +// encodeTextDelta 编码文本增量 +func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte { + delta := map[string]any{ + "content": event.Delta.Text, + } + if e.bufferedStart != nil { + e.bufferedStart = nil + } + return e.encodeDelta(delta) +} + +// encodeInputJSONDelta 编码 JSON 输入增量 +func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte { + if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil { + // 首次 delta,含 id 和 name + start := e.bufferedStart.ContentBlock + tcIdx := 0 + if start.ID != "" { + tcIdx = e.toolCallIndexMap[start.ID] + } + delta := map[string]any{ + "tool_calls": []map[string]any{{ + "index": tcIdx, + "id": start.ID, + "type": "function", + "function": map[string]any{ + "name": start.Name, + "arguments": event.Delta.PartialJSON, + }, + }}, + } + e.bufferedStart = nil + return 
e.encodeDelta(delta) + } + + // 后续 delta,仅含 arguments + // 通过 index 查找 tool call + tcIdx := 0 + if event.Index != nil { + for id, idx := range e.toolCallIndexMap { + if idx == tcIdx { + _ = id + break + } + } + } + delta := map[string]any{ + "tool_calls": []map[string]any{{ + "index": tcIdx, + "function": map[string]any{ + "arguments": event.Delta.PartialJSON, + }, + }}, + } + return e.encodeDelta(delta) +} + +// encodeThinkingDelta 编码思考增量 +func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte { + delta := map[string]any{ + "reasoning_content": event.Delta.Thinking, + } + if e.bufferedStart != nil { + e.bufferedStart = nil + } + return e.encodeDelta(delta) +} + +// encodeMessageDelta 编码消息增量事件 +func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte { + var chunks [][]byte + + if event.StopReason != nil { + fr := mapCanonicalToFinishReason(*event.StopReason) + chunk := map[string]any{ + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": fr, + }}, + } + chunks = append(chunks, e.marshalChunk(chunk)...) + } + + if event.Usage != nil { + chunk := map[string]any{ + "choices": []map[string]any{}, + "usage": encodeUsage(*event.Usage), + } + chunks = append(chunks, e.marshalChunk(chunk)...) 
+ } + + return chunks +} + +// encodeDelta 编码 delta 到 SSE chunk +func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte { + chunk := map[string]any{ + "choices": []map[string]any{{ + "index": 0, + "delta": delta, + }}, + } + return e.marshalChunk(chunk) +} + +// marshalChunk 序列化 chunk 为 SSE data +func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte { + data, err := json.Marshal(chunk) + if err != nil { + return nil + } + return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))} +} diff --git a/backend/internal/conversion/openai/stream_encoder_test.go b/backend/internal/conversion/openai/stream_encoder_test.go new file mode 100644 index 0000000..83ca20f --- /dev/null +++ b/backend/internal/conversion/openai/stream_encoder_test.go @@ -0,0 +1,172 @@ +package openai + +import ( + "encoding/json" + "strings" + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamEncoder_MessageStart(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewMessageStartEvent("chatcmpl-1", "gpt-4") + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.True(t, strings.HasPrefix(s, "data: ")) + assert.Contains(t, s, "chatcmpl-1") + assert.Contains(t, s, "chat.completion.chunk") + + var payload map[string]any + data := strings.TrimPrefix(s, "data: ") + data = strings.TrimRight(data, "\n") + require.NoError(t, json.Unmarshal([]byte(data), &payload)) + choices := payload["choices"].([]any) + delta := choices[0].(map[string]any)["delta"].(map[string]any) + assert.Equal(t, "assistant", delta["role"]) +} + +func TestStreamEncoder_TextDelta(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"}) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "你好") +} + +func 
TestStreamEncoder_MessageStop(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewMessageStopEvent() + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + assert.Equal(t, "data: [DONE]\n\n", string(chunks[0])) +} + +func TestStreamEncoder_Buffering(t *testing.T) { + e := NewStreamEncoder() + + // ContentBlockStart 应被缓冲,不输出 + startEvent := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""}) + chunks := e.EncodeEvent(startEvent) + assert.Nil(t, chunks) + assert.NotNil(t, e.bufferedStart) + + // 第一个 delta 触发输出(清空缓冲) + deltaEvent := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "hello"}) + chunks = e.EncodeEvent(deltaEvent) + require.NotEmpty(t, chunks) + assert.Nil(t, e.bufferedStart) +} + +func TestStreamEncoder_ContentBlockStop_ReturnsNil(t *testing.T) { + e := NewStreamEncoder() + idx := 0 + event := canonical.CanonicalStreamEvent{ + Type: canonical.EventContentBlockStop, + Index: &idx, + } + chunks := e.EncodeEvent(event) + assert.Nil(t, chunks) +} + +func TestStreamEncoder_Ping_ReturnsNil(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewPingEvent() + chunks := e.EncodeEvent(event) + assert.Nil(t, chunks) +} + +func TestStreamEncoder_Error_ReturnsNil(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewErrorEvent("test_error", "测试错误") + chunks := e.EncodeEvent(event) + assert.Nil(t, chunks) +} + +func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) { + e := NewStreamEncoder() + chunks := e.Flush() + assert.Nil(t, chunks) +} + +func TestStreamEncoder_ThinkingDelta(t *testing.T) { + e := NewStreamEncoder() + event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{ + Type: string(canonical.DeltaTypeThinking), + Thinking: "思考内容", + }) + + chunks := e.EncodeEvent(event) + require.Len(t, chunks, 1) + + s := string(chunks[0]) + assert.Contains(t, s, "reasoning_content") + assert.Contains(t, s, "思考内容") +} + +func 
TestStreamEncoder_InputJSONDelta(t *testing.T) { + e := NewStreamEncoder() + + e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{ + Type: "tool_use", + ID: "call_1", + Name: "get_weather", + })) + + event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{ + Type: string(canonical.DeltaTypeInputJSON), + PartialJSON: "{\"city\":\"北京\"}", + }) + + chunks := e.EncodeEvent(event) + require.NotEmpty(t, chunks) + + s := string(chunks[0]) + assert.Contains(t, s, "tool_calls") + assert.Contains(t, s, "北京") +} + +func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) { + e := NewStreamEncoder() + sr := canonical.StopReasonEndTurn + event := canonical.CanonicalStreamEvent{ + Type: canonical.EventMessageDelta, + StopReason: &sr, + } + + chunks := e.EncodeEvent(event) + require.NotEmpty(t, chunks) + + s := string(chunks[0]) + assert.Contains(t, s, "finish_reason") + assert.Contains(t, s, "stop") +} + +func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) { + e := NewStreamEncoder() + usage := canonical.CanonicalUsage{ + InputTokens: 100, + OutputTokens: 50, + } + event := canonical.CanonicalStreamEvent{ + Type: canonical.EventMessageDelta, + Usage: &usage, + } + + chunks := e.EncodeEvent(event) + require.NotEmpty(t, chunks) + + s := string(chunks[0]) + assert.Contains(t, s, "usage") + assert.Contains(t, s, "prompt_tokens") +} diff --git a/backend/internal/conversion/openai/types.go b/backend/internal/conversion/openai/types.go new file mode 100644 index 0000000..08add3f --- /dev/null +++ b/backend/internal/conversion/openai/types.go @@ -0,0 +1,245 @@ +package openai + +import "encoding/json" + +// ChatCompletionRequest OpenAI Chat Completion 请求 +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int 
`json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + User string `json:"user,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + N *int `json:"n,omitempty"` + Seed *int `json:"seed,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + + // 已废弃字段 + Functions []FunctionDef `json:"functions,omitempty"` + FunctionCall any `json:"function_call,omitempty"` +} + +// Message OpenAI 消息 +type Message struct { + Role string `json:"role"` + Content any `json:"content"` + Name string `json:"name,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Refusal string `json:"refusal,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + + // 已废弃 + FunctionCall *FunctionCallMsg `json:"function_call,omitempty"` +} + +// ToolCall OpenAI 工具调用 +type ToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Custom *CustomTool `json:"custom,omitempty"` +} + +// FunctionCall OpenAI 函数调用 +type FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + +// CustomTool 自定义工具 +type CustomTool struct { + Name string `json:"name"` + Input string `json:"input"` +} + +// FunctionCallMsg 已废弃的函数调用消息 +type FunctionCallMsg struct { + Name string `json:"name"` + Arguments string 
`json:"arguments"` +} + +// Tool OpenAI 工具定义 +type Tool struct { + Type string `json:"type"` + Function *FunctionDef `json:"function,omitempty"` +} + +// FunctionDef OpenAI 函数定义 +type FunctionDef struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ResponseFormat OpenAI 响应格式 +type ResponseFormat struct { + Type string `json:"type"` + JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"` +} + +// JSONSchemaDef JSON Schema 定义 +type JSONSchemaDef struct { + Name string `json:"name"` + Schema json.RawMessage `json:"schema,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// StreamOptions 流式选项 +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +// ChatCompletionResponse OpenAI Chat Completion 响应 +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// Choice OpenAI 选择项 +type Choice struct { + Index int `json:"index"` + Message *Message `json:"message,omitempty"` + Delta *Message `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason"` + Logprobs any `json:"logprobs,omitempty"` +} + +// Usage OpenAI 用量 +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` +} + +// PromptTokensDetails 提示 Token 详情 +type PromptTokensDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` + AudioTokens 
int `json:"audio_tokens,omitempty"` +} + +// CompletionTokensDetails 完成 Token 详情 +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` +} + +// StreamChunk OpenAI 流式 chunk +type StreamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` +} + +// ModelsResponse OpenAI 模型列表响应 +type ModelsResponse struct { + Object string `json:"object"` + Data []ModelItem `json:"data"` +} + +// ModelItem OpenAI 模型项 +type ModelItem struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + +// ModelInfoResponse OpenAI 模型详情响应 +type ModelInfoResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + +// EmbeddingRequest OpenAI 嵌入请求 +type EmbeddingRequest struct { + Model string `json:"model"` + Input any `json:"input"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` +} + +// EmbeddingResponse OpenAI 嵌入响应 +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingData `json:"data"` + Model string `json:"model"` + Usage EmbeddingUsage `json:"usage"` +} + +// EmbeddingData 嵌入数据项 +type EmbeddingData struct { + Index int `json:"index"` + Embedding any `json:"embedding"` +} + +// EmbeddingUsage 嵌入用量 +type EmbeddingUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// RerankRequest OpenAI 重排序请求 +type RerankRequest struct { + Model string 
`json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN *int `json:"top_n,omitempty"` + ReturnDocuments *bool `json:"return_documents,omitempty"` +} + +// RerankResponse OpenAI 重排序响应 +type RerankResponse struct { + Results []RerankResult `json:"results"` + Model string `json:"model"` +} + +// RerankResult 重排序结果项 +type RerankResult struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + Document *string `json:"document,omitempty"` +} + +// ErrorResponse OpenAI 错误响应 +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail 错误详情 +type ErrorDetail struct { + Message string `json:"message"` + Type string `json:"type"` + Param any `json:"param"` + Code string `json:"code"` +} diff --git a/backend/internal/conversion/provider.go b/backend/internal/conversion/provider.go new file mode 100644 index 0000000..991b84d --- /dev/null +++ b/backend/internal/conversion/provider.go @@ -0,0 +1,19 @@ +package conversion + +// TargetProvider 目标上游供应商信息 +type TargetProvider struct { + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + ModelName string `json:"model_name"` + AdapterConfig map[string]any `json:"adapter_config,omitempty"` +} + +// NewTargetProvider 创建目标供应商 +func NewTargetProvider(baseURL, apiKey, modelName string) *TargetProvider { + return &TargetProvider{ + BaseURL: baseURL, + APIKey: apiKey, + ModelName: modelName, + AdapterConfig: make(map[string]any), + } +} diff --git a/backend/internal/conversion/stream.go b/backend/internal/conversion/stream.go new file mode 100644 index 0000000..b4def9f --- /dev/null +++ b/backend/internal/conversion/stream.go @@ -0,0 +1,107 @@ +package conversion + +import "nex/backend/internal/conversion/canonical" + +// StreamDecoder 流式解码器接口 +type StreamDecoder interface { + ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent + Flush() []canonical.CanonicalStreamEvent +} + +// StreamEncoder 流式编码器接口 +type StreamEncoder 
interface { + EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte + Flush() [][]byte +} + +// StreamConverter 流式转换器接口 +type StreamConverter interface { + ProcessChunk(rawChunk []byte) [][]byte + Flush() [][]byte +} + +// PassthroughStreamConverter 同协议透传流式转换器 +type PassthroughStreamConverter struct{} + +// NewPassthroughStreamConverter 创建透传流式转换器 +func NewPassthroughStreamConverter() *PassthroughStreamConverter { + return &PassthroughStreamConverter{} +} + +// ProcessChunk 直接传递原始字节 +func (c *PassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { + return [][]byte{rawChunk} +} + +// Flush 无缓冲数据 +func (c *PassthroughStreamConverter) Flush() [][]byte { + return nil +} + +// CanonicalStreamConverter 跨协议规范流式转换器 +type CanonicalStreamConverter struct { + decoder StreamDecoder + encoder StreamEncoder + chain *MiddlewareChain + ctx ConversionContext + clientProtocol string + providerProtocol string +} + +// NewCanonicalStreamConverter 创建规范流式转换器 +func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *CanonicalStreamConverter { + return &CanonicalStreamConverter{ + decoder: decoder, + encoder: encoder, + } +} + +// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器 +func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter { + return &CanonicalStreamConverter{ + decoder: decoder, + encoder: encoder, + chain: chain, + ctx: ctx, + clientProtocol: clientProtocol, + providerProtocol: providerProtocol, + } +} + +// ProcessChunk 解码 → 中间件 → 编码管道 +func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { + events := c.decoder.ProcessChunk(rawChunk) + var result [][]byte + for i := range events { + if c.chain != nil { + processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx) + if err != nil { + continue + } + events[i] = 
*processed + } + chunks := c.encoder.EncodeEvent(events[i]) + result = append(result, chunks...) + } + return result +} + +// Flush 刷新解码器和编码器缓冲区 +func (c *CanonicalStreamConverter) Flush() [][]byte { + events := c.decoder.Flush() + var result [][]byte + for i := range events { + if c.chain != nil { + processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx) + if err != nil { + continue + } + events[i] = *processed + } + chunks := c.encoder.EncodeEvent(events[i]) + result = append(result, chunks...) + } + encoderChunks := c.encoder.Flush() + result = append(result, encoderChunks...) + return result +} diff --git a/backend/internal/conversion/stream_test.go b/backend/internal/conversion/stream_test.go new file mode 100644 index 0000000..4531cc5 --- /dev/null +++ b/backend/internal/conversion/stream_test.go @@ -0,0 +1,130 @@ +package conversion + +import ( + "testing" + + "nex/backend/internal/conversion/canonical" + + "github.com/stretchr/testify/assert" +) + +func TestPassthroughStreamConverter_ProcessChunk(t *testing.T) { + converter := NewPassthroughStreamConverter() + data := []byte("hello world") + result := converter.ProcessChunk(data) + assert.Len(t, result, 1) + assert.Equal(t, data, result[0]) +} + +func TestPassthroughStreamConverter_Flush(t *testing.T) { + converter := NewPassthroughStreamConverter() + result := converter.Flush() + assert.Nil(t, result) +} + +// mockStreamDecoder 模拟流式解码器 +type mockStreamDecoder struct { + chunks [][]canonical.CanonicalStreamEvent + flush []canonical.CanonicalStreamEvent +} + +// ProcessChunk 弹出下一个分片的事件 +func (d *mockStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { + if len(d.chunks) == 0 { + return nil + } + events := d.chunks[0] + d.chunks = d.chunks[1:] + return events +} + +// Flush 返回刷新事件 +func (d *mockStreamDecoder) Flush() []canonical.CanonicalStreamEvent { + return d.flush +} + +// mockStreamEncoder 模拟流式编码器 +type mockStreamEncoder struct { + 
events [][]byte + flush [][]byte +} + +// EncodeEvent 返回编码后的事件 +func (e *mockStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { + if len(e.events) == 0 { + return nil + } + return e.events +} + +// Flush 返回编码器刷新数据 +func (e *mockStreamEncoder) Flush() [][]byte { + return e.flush +} + +func TestCanonicalStreamConverter_ProcessChunk(t *testing.T) { + event := canonical.NewMessageStartEvent("id-1", "gpt-4") + decoder := &mockStreamDecoder{ + chunks: [][]canonical.CanonicalStreamEvent{{event}}, + } + encoder := &mockStreamEncoder{ + events: [][]byte{[]byte("data: test\n\n")}, + } + + converter := NewCanonicalStreamConverter(decoder, encoder) + result := converter.ProcessChunk([]byte("raw")) + + assert.Len(t, result, 1) + assert.Equal(t, []byte("data: test\n\n"), result[0]) +} + +func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) { + var records []string + event := canonical.NewMessageStartEvent("id-1", "gpt-4") + decoder := &mockStreamDecoder{ + chunks: [][]canonical.CanonicalStreamEvent{{event}}, + } + encoder := &mockStreamEncoder{ + events: [][]byte{[]byte("data: ok\n\n")}, + } + + chain := NewMiddlewareChain() + chain.Use(&recordingMiddleware{name: "mw1", records: &records}) + ctx := NewConversionContext(InterfaceTypeChat) + + converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic") + result := converter.ProcessChunk([]byte("raw")) + + assert.Len(t, result, 1) + assert.Equal(t, []string{"stream:mw1"}, records) + assert.Equal(t, []byte("data: ok\n\n"), result[0]) +} + +func TestCanonicalStreamConverter_Flush(t *testing.T) { + decoder := &mockStreamDecoder{ + flush: []canonical.CanonicalStreamEvent{ + canonical.NewMessageStopEvent(), + }, + } + encoder := &mockStreamEncoder{ + events: [][]byte{[]byte("data: stop\n\n")}, + flush: [][]byte{[]byte("data: flush\n\n")}, + } + + converter := NewCanonicalStreamConverter(decoder, encoder) + result := converter.Flush() + + assert.Len(t, 
result, 2) + assert.Equal(t, []byte("data: stop\n\n"), result[0]) + assert.Equal(t, []byte("data: flush\n\n"), result[1]) +} + +func TestCanonicalStreamConverter_EmptyDecoder(t *testing.T) { + decoder := &mockStreamDecoder{} + encoder := &mockStreamEncoder{} + + converter := NewCanonicalStreamConverter(decoder, encoder) + result := converter.ProcessChunk([]byte("raw")) + + assert.Nil(t, result) +} diff --git a/backend/internal/domain/provider.go b/backend/internal/domain/provider.go index 199d18e..f01b2d8 100644 --- a/backend/internal/domain/provider.go +++ b/backend/internal/domain/provider.go @@ -8,6 +8,7 @@ type Provider struct { Name string `json:"name"` APIKey string `json:"api_key"` BaseURL string `json:"base_url"` + Protocol string `json:"protocol"` Enabled bool `json:"enabled"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/handler/anthropic_handler.go b/backend/internal/handler/anthropic_handler.go deleted file mode 100644 index 13e5aef..0000000 --- a/backend/internal/handler/anthropic_handler.go +++ /dev/null @@ -1,217 +0,0 @@ -package handler - -import ( - "bufio" - "fmt" - "net/http" - - "github.com/gin-gonic/gin" - - appErrors "nex/backend/pkg/errors" - - "nex/backend/internal/domain" - "nex/backend/internal/protocol/anthropic" - "nex/backend/internal/protocol/openai" - "nex/backend/internal/provider" - "nex/backend/internal/service" -) - -// AnthropicHandler Anthropic 协议处理器 -type AnthropicHandler struct { - client provider.ProviderClient - routingService service.RoutingService - statsService service.StatsService -} - -// NewAnthropicHandler 创建 Anthropic 处理器 -func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler { - return &AnthropicHandler{ - client: client, - routingService: routingService, - statsService: statsService, - } -} - -// HandleMessages 处理 Messages 请求 -func (h *AnthropicHandler) 
HandleMessages(c *gin.Context) { - var req anthropic.MessagesRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "invalid_request_error", - Message: "无效的请求格式: " + err.Error(), - }, - }) - return - } - - // 请求验证 - if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil { - errMsg := formatValidationErrors(validationErrors) - c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "invalid_request_error", - Message: errMsg, - }, - }) - return - } - - if err := h.checkMultimodalContent(&req); err != nil { - c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "invalid_request_error", - Message: err.Error(), - }, - }) - return - } - - openaiReq, err := anthropic.ConvertRequest(&req) - if err != nil { - c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "invalid_request_error", - Message: "请求转换失败: " + err.Error(), - }, - }) - return - } - - routeResult, err := h.routingService.Route(openaiReq.Model) - if err != nil { - h.handleError(c, err) - return - } - - if req.Stream { - h.handleStreamRequest(c, openaiReq, routeResult) - } else { - h.handleNonStreamRequest(c, openaiReq, routeResult) - } -} - -func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { - openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "api_error", - Message: "供应商请求失败: " + err.Error(), - }, - }) - return - } - - anthropicResp, err := anthropic.ConvertResponse(openaiResp) - if err != nil { 
- c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "api_error", - Message: "响应转换失败: " + err.Error(), - }, - }) - return - } - - go func() { - _ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model) - }() - - c.JSON(http.StatusOK, anthropicResp) -} - -func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { - eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "api_error", - Message: "供应商请求失败: " + err.Error(), - }, - }) - return - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - - writer := bufio.NewWriter(c.Writer) - - converter := anthropic.NewStreamConverter( - fmt.Sprintf("msg_%s", routeResult.Provider.ID), - openaiReq.Model, - ) - - for event := range eventChan { - if event.Error != nil { - break - } - - if event.Done { - break - } - - chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data) - if err != nil { - continue - } - - anthropicEvents, err := converter.ConvertChunk(chunk) - if err != nil { - continue - } - - for _, ae := range anthropicEvents { - eventStr, err := anthropic.SerializeEvent(ae) - if err != nil { - continue - } - writer.WriteString(eventStr) - writer.Flush() - } - } - - go func() { - _ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model) - }() -} - -func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error { - for _, msg := range req.Messages { - for _, block := range msg.Content { - if block.Type == "image" { - return fmt.Errorf("MVP 不支持多模态内容(图片)") - } - } - } - return nil -} - -func (h *AnthropicHandler) handleError(c 
*gin.Context, err error) { - if appErr, ok := appErrors.AsAppError(err); ok { - c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "not_found_error", - Message: appErr.Message, - }, - }) - return - } - c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "internal_error", - Message: "内部错误: " + err.Error(), - }, - }) -} diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go index 8de985b..e15006f 100644 --- a/backend/internal/handler/handler_test.go +++ b/backend/internal/handler/handler_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http/httptest" + "strings" "testing" "time" @@ -13,7 +15,6 @@ import ( "github.com/stretchr/testify/require" "nex/backend/internal/domain" - "nex/backend/internal/protocol/openai" "nex/backend/internal/provider" appErrors "nex/backend/pkg/errors" ) @@ -34,8 +35,8 @@ func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error } type mockStatsService struct { - err error - stats []domain.UsageStats + err error + stats []domain.UsageStats aggrResult []map[string]interface{} } @@ -84,61 +85,14 @@ func (m *mockModelService) Update(id string, updates map[string]interface{}) err func (m *mockModelService) Delete(id string) error { return m.err } type mockProviderClient struct { - resp *openai.ChatCompletionResponse - eventChan chan provider.StreamEvent - err error + err error } -func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) { - return m.resp, m.err +func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) { + return nil, m.err } -func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan 
provider.StreamEvent, error) { - return m.eventChan, m.err -} - -// ============ OpenAI Handler 测试 ============ - -func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) { - h := NewOpenAIHandler(nil, nil, nil) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid"))) - - h.HandleChatCompletions(c) - assert.Equal(t, 400, w.Code) -} - -func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) { - h := NewOpenAIHandler(nil, nil, nil) - - // 缺少 model 字段 - body, _ := json.Marshal(map[string]interface{}{ - "messages": []map[string]string{{"role": "user", "content": "hi"}}, - }) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body)) - c.Request.Header.Set("Content-Type", "application/json") - - h.HandleChatCompletions(c) - assert.Equal(t, 400, w.Code) -} - -func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) { - routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound} - h := NewOpenAIHandler(nil, routingSvc, nil) - - body, _ := json.Marshal(map[string]interface{}{ - "model": "nonexistent", - "messages": []map[string]string{{"role": "user", "content": "hi"}}, - }) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body)) - c.Request.Header.Set("Content-Type", "application/json") - - h.HandleChatCompletions(c) - assert.Equal(t, 404, w.Code) +func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) { + return nil, m.err } // ============ Provider Handler 测试 ============ @@ -283,8 +237,16 @@ func TestFormatValidationErrors(t *testing.T) { "model": "模型名称不能为空", "messages": "消息列表不能为空", } - result := formatValidationErrors(errs) + result := formatMapErrors(errs) 
require.Contains(t, result, "请求验证失败") require.Contains(t, result, "model") require.Contains(t, result, "messages") } + +func formatMapErrors(errs map[string]string) string { + parts := make([]string, 0, len(errs)) + for field, msg := range errs { + parts = append(parts, fmt.Sprintf("%s: %s", field, msg)) + } + return "请求验证失败: " + strings.Join(parts, "; ") +} diff --git a/backend/internal/handler/openai_handler.go b/backend/internal/handler/openai_handler.go deleted file mode 100644 index a15fc0a..0000000 --- a/backend/internal/handler/openai_handler.go +++ /dev/null @@ -1,157 +0,0 @@ -package handler - -import ( - "bufio" - "fmt" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - - appErrors "nex/backend/pkg/errors" - - "nex/backend/internal/domain" - "nex/backend/internal/protocol/openai" - "nex/backend/internal/provider" - "nex/backend/internal/service" -) - -// OpenAIHandler OpenAI 协议处理器 -type OpenAIHandler struct { - client provider.ProviderClient - routingService service.RoutingService - statsService service.StatsService -} - -// NewOpenAIHandler 创建 OpenAI 处理器 -func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler { - return &OpenAIHandler{ - client: client, - routingService: routingService, - statsService: statsService, - } -} - -// HandleChatCompletions 处理 Chat Completions 请求 -func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) { - var req openai.ChatCompletionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "无效的请求格式: " + err.Error(), - Type: "invalid_request_error", - }, - }) - return - } - - // 请求验证 - if validationErrors := openai.ValidateRequest(&req); validationErrors != nil { - c.JSON(http.StatusBadRequest, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: formatValidationErrors(validationErrors), - Type: "invalid_request_error", - }, - 
}) - return - } - - routeResult, err := h.routingService.Route(req.Model) - if err != nil { - h.handleError(c, err) - return - } - - if req.Stream { - h.handleStreamRequest(c, &req, routeResult) - } else { - h.handleNonStreamRequest(c, &req, routeResult) - } -} - -func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { - resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "供应商请求失败: " + err.Error(), - Type: "api_error", - }, - }) - return - } - - go func() { - _ = h.statsService.Record(routeResult.Provider.ID, req.Model) - }() - - c.JSON(http.StatusOK, resp) -} - -func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { - eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) - if err != nil { - c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "供应商请求失败: " + err.Error(), - Type: "api_error", - }, - }) - return - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - - writer := bufio.NewWriter(c.Writer) - - for event := range eventChan { - if event.Error != nil { - break - } - - if event.Done { - writer.WriteString("data: [DONE]\n\n") - writer.Flush() - break - } - - writer.WriteString("data: ") - writer.Write(event.Data) - writer.WriteString("\n\n") - writer.Flush() - } - - go func() { - _ = h.statsService.Record(routeResult.Provider.ID, req.Model) - }() -} - -func (h *OpenAIHandler) handleError(c *gin.Context, err error) { - if appErr, ok := appErrors.AsAppError(err); ok { - c.JSON(appErr.HTTPStatus, openai.ErrorResponse{ - Error: 
openai.ErrorDetail{ - Message: appErr.Message, - Type: "invalid_request_error", - Code: appErr.Code, - }, - }) - return - } - c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "内部错误: " + err.Error(), - Type: "internal_error", - }, - }) -} - -// formatValidationErrors 将验证错误 map 格式化为字符串 -func formatValidationErrors(errors map[string]string) string { - parts := make([]string, 0, len(errors)) - for field, msg := range errors { - parts = append(parts, fmt.Sprintf("%s: %s", field, msg)) - } - return "请求验证失败: " + strings.Join(parts, "; ") -} diff --git a/backend/internal/handler/provider_handler.go b/backend/internal/handler/provider_handler.go index 313252b..e5cd4e2 100644 --- a/backend/internal/handler/provider_handler.go +++ b/backend/internal/handler/provider_handler.go @@ -26,10 +26,11 @@ func NewProviderHandler(providerService service.ProviderService) *ProviderHandle // CreateProvider 创建供应商 func (h *ProviderHandler) CreateProvider(c *gin.Context) { var req struct { - ID string `json:"id" binding:"required"` - Name string `json:"name" binding:"required"` - APIKey string `json:"api_key" binding:"required"` - BaseURL string `json:"base_url" binding:"required"` + ID string `json:"id" binding:"required"` + Name string `json:"name" binding:"required"` + APIKey string `json:"api_key" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Protocol string `json:"protocol"` } if err := c.ShouldBindJSON(&req); err != nil { @@ -39,11 +40,17 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) { return } + protocol := req.Protocol + if protocol == "" { + protocol = "openai" + } + provider := &domain.Provider{ - ID: req.ID, - Name: req.Name, - APIKey: req.APIKey, - BaseURL: req.BaseURL, + ID: req.ID, + Name: req.Name, + APIKey: req.APIKey, + BaseURL: req.BaseURL, + Protocol: protocol, } err := h.providerService.Create(provider) diff --git a/backend/internal/handler/proxy_handler.go 
b/backend/internal/handler/proxy_handler.go new file mode 100644 index 0000000..0cca484 --- /dev/null +++ b/backend/internal/handler/proxy_handler.go @@ -0,0 +1,371 @@ +package handler + +import ( + "bufio" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "nex/backend/internal/conversion" + "nex/backend/internal/domain" + "nex/backend/internal/provider" + "nex/backend/internal/service" +) + +// ProxyHandler 统一代理处理器 +type ProxyHandler struct { + engine *conversion.ConversionEngine + client provider.ProviderClient + routingService service.RoutingService + providerService service.ProviderService + statsService service.StatsService + logger *zap.Logger +} + +// NewProxyHandler 创建统一代理处理器 +func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler { + return &ProxyHandler{ + engine: engine, + client: client, + routingService: routingService, + providerService: providerService, + statsService: statsService, + logger: zap.L(), + } +} + +// HandleProxy 处理代理请求 +func (h *ProxyHandler) HandleProxy(c *gin.Context) { + // 从 URL 提取 clientProtocol: /{protocol}/v1/... 
+ clientProtocol := c.Param("protocol") + if clientProtocol == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"}) + return + } + + // 原始路径: /v1/{path} + path := c.Param("path") + if strings.HasPrefix(path, "/") { + path = path[1:] + } + nativePath := "/v1/" + path + + // 读取请求体 + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) + return + } + + // 解析 model 名称(从 JSON body 中提取,GET 请求无 body) + modelName := "" + if len(body) > 0 { + modelName = extractModelName(body) + } + + // 构建输入 HTTPRequestSpec + inSpec := conversion.HTTPRequestSpec{ + URL: nativePath, + Method: c.Request.Method, + Headers: extractHeaders(c), + Body: body, + } + + // 路由 + routeResult, err := h.routingService.Route(modelName) + if err != nil { + // GET 请求或无法提取 model 时,直接转发到上游 + if len(body) == 0 || modelName == "" { + h.forwardPassthrough(c, inSpec, clientProtocol) + return + } + h.writeError(c, err, clientProtocol) + return + } + + // 确定 providerProtocol + providerProtocol := routeResult.Provider.Protocol + if providerProtocol == "" { + providerProtocol = "openai" + } + + // 构建 TargetProvider + targetProvider := conversion.NewTargetProvider( + routeResult.Provider.BaseURL, + routeResult.Provider.APIKey, + routeResult.Model.ModelName, + ) + + // 判断是否流式 + isStream := h.isStreamRequest(body, clientProtocol, nativePath) + + if isStream { + h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) + } else { + h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) + } +} + +// handleNonStream 处理非流式请求 +func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { + // 转换请求 + outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) + if err != nil { + 
h.logger.Error("转换请求失败", zap.String("error", err.Error())) + h.writeConversionError(c, err, clientProtocol) + return + } + + // 发送请求 + resp, err := h.client.Send(c.Request.Context(), *outSpec) + if err != nil { + h.logger.Error("发送请求失败", zap.String("error", err.Error())) + h.writeConversionError(c, err, clientProtocol) + return + } + + // 转换响应 + interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol) + convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType) + if err != nil { + h.logger.Error("转换响应失败", zap.String("error", err.Error())) + h.writeConversionError(c, err, clientProtocol) + return + } + + // 设置响应头 + for k, v := range convertedResp.Headers { + c.Header(k, v) + } + if c.GetHeader("Content-Type") == "" { + c.Header("Content-Type", "application/json") + } + + c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) + + go func() { + _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) + }() +} + +// handleStream 处理流式请求 +func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { + // 转换请求 + outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + + // 创建流式转换器 + streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + + // 发送流式请求 + eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + writer := bufio.NewWriter(c.Writer) + + for event := range 
eventChan { + if event.Error != nil { + h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) + break + } + if event.Done { + // flush 转换器 + chunks := streamConverter.Flush() + for _, chunk := range chunks { + writer.Write(chunk) + writer.Flush() + } + break + } + + chunks := streamConverter.ProcessChunk(event.Data) + for _, chunk := range chunks { + writer.Write(chunk) + writer.Flush() + } + } + + go func() { + _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) + }() +} + +// isStreamRequest 判断是否流式请求 +func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { + ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) + if ifaceType != conversion.InterfaceTypeChat { + return false + } + for i, b := range body { + if b == '"' && i+8 <= len(body) { + if string(body[i:i+8]) == `"stream"` { + for j := i + 8; j < len(body) && j < i+20; j++ { + if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" { + return true + } + } + } + } + } + return false +} + +// writeConversionError 写入转换错误 +func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { + if convErr, ok := err.(*conversion.ConversionError); ok { + body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) + c.Data(statusCode, "application/json", body) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +} + +// writeError 写入路由错误 +func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) +} + +// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) +func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) { + registry := h.engine.GetRegistry() + adapter, err := registry.Get(clientProtocol) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) + return + 
} + + providers, err := h.providerService.List() + if err != nil || len(providers) == 0 { + h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL)) + c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"}) + return + } + + p := providers[0] + providerProtocol := p.Protocol + if providerProtocol == "" { + providerProtocol = "openai" + } + + ifaceType := adapter.DetectInterfaceType(inSpec.URL) + + targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") + + var outSpec *conversion.HTTPRequestSpec + if clientProtocol == providerProtocol { + upstreamURL := p.BaseURL + inSpec.URL + headers := adapter.BuildHeaders(targetProvider) + if _, ok := headers["Content-Type"]; !ok { + headers["Content-Type"] = "application/json" + } + outSpec = &conversion.HTTPRequestSpec{ + URL: upstreamURL, + Method: inSpec.Method, + Headers: headers, + Body: inSpec.Body, + } + } else { + outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + } + + resp, err := h.client.Send(c.Request.Context(), *outSpec) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + + convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType) + if err != nil { + h.writeConversionError(c, err, clientProtocol) + return + } + + for k, v := range convertedResp.Headers { + c.Header(k, v) + } + if c.GetHeader("Content-Type") == "" { + c.Header("Content-Type", "application/json") + } + c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) +} + +// extractModelName 从 JSON body 中提取 model +func extractModelName(body []byte) string { + inQuote := false + escaped := false + keyStart := -1 + keyEnd := -1 + lookingForKey := true + lookingForValue := false + valueStart := -1 + + for i := 0; i < len(body); i++ { + b := body[i] + if escaped { + escaped = false + continue + } + if b == 
'\\' { + escaped = true + continue + } + if b == '"' { + if !inQuote { + inQuote = true + if lookingForKey { + keyStart = i + 1 + } + if lookingForValue { + valueStart = i + 1 + } + } else { + inQuote = false + if lookingForKey && keyStart >= 0 { + keyEnd = i + if string(body[keyStart:keyEnd]) == "model" { + lookingForKey = false + lookingForValue = true + } + } else if lookingForValue && valueStart >= 0 { + return string(body[valueStart:i]) + } + } + } + if !inQuote && lookingForValue && b == ':' { + // 等待值开始 + } + } + return "" +} + +// extractHeaders 从 Gin context 提取请求头 +func extractHeaders(c *gin.Context) map[string]string { + headers := make(map[string]string) + for k, vs := range c.Request.Header { + if len(vs) > 0 { + headers[k] = vs[0] + } + } + return headers +} diff --git a/backend/internal/protocol/anthropic/converter.go b/backend/internal/protocol/anthropic/converter.go deleted file mode 100644 index 8b9675b..0000000 --- a/backend/internal/protocol/anthropic/converter.go +++ /dev/null @@ -1,234 +0,0 @@ -package anthropic - -import ( - "encoding/json" - "fmt" - - "nex/backend/internal/protocol/openai" -) - -// ConvertRequest 将 Anthropic 请求转换为 OpenAI 请求 -func ConvertRequest(anthropicReq *MessagesRequest) (*openai.ChatCompletionRequest, error) { - openaiReq := &openai.ChatCompletionRequest{ - Model: anthropicReq.Model, - Temperature: anthropicReq.Temperature, - TopP: anthropicReq.TopP, - Stream: anthropicReq.Stream, - } - - // 处理 max_tokens(Anthropic 要求必须有,默认 4096) - if anthropicReq.MaxTokens > 0 { - openaiReq.MaxTokens = &anthropicReq.MaxTokens - } else { - defaultMax := 4096 - openaiReq.MaxTokens = &defaultMax - } - - // 处理 stop_sequences - if len(anthropicReq.StopSequences) > 0 { - openaiReq.Stop = anthropicReq.StopSequences - } - - // 转换 system 消息 - messages := make([]openai.Message, 0) - if anthropicReq.System != "" { - messages = append(messages, openai.Message{ - Role: "system", - Content: anthropicReq.System, - }) - } - - // 转换 messages - for _, 
msg := range anthropicReq.Messages { - openaiMsg, err := convertMessage(msg) - if err != nil { - return nil, err - } - messages = append(messages, openaiMsg...) - } - openaiReq.Messages = messages - - // 转换 tools - if len(anthropicReq.Tools) > 0 { - openaiReq.Tools = make([]openai.Tool, len(anthropicReq.Tools)) - for i, tool := range anthropicReq.Tools { - openaiReq.Tools[i] = openai.Tool{ - Type: "function", - Function: openai.FunctionDefinition{ - Name: tool.Name, - Description: tool.Description, - Parameters: tool.InputSchema, - }, - } - } - } - - // 转换 tool_choice - if anthropicReq.ToolChoice != nil { - toolChoice, err := convertToolChoice(anthropicReq.ToolChoice) - if err != nil { - return nil, err - } - openaiReq.ToolChoice = toolChoice - } - - return openaiReq, nil -} - -// ConvertResponse 将 OpenAI 响应转换为 Anthropic 响应 -func ConvertResponse(openaiResp *openai.ChatCompletionResponse) (*MessagesResponse, error) { - anthropicResp := &MessagesResponse{ - ID: openaiResp.ID, - Type: "message", - Role: "assistant", - Model: openaiResp.Model, - Usage: Usage{ - InputTokens: openaiResp.Usage.PromptTokens, - OutputTokens: openaiResp.Usage.CompletionTokens, - }, - } - - // 转换 content - if len(openaiResp.Choices) > 0 { - choice := openaiResp.Choices[0] - content := make([]ContentBlock, 0) - - if choice.Message != nil { - // 文本内容 - if choice.Message.Content != "" { - if str, ok := choice.Message.Content.(string); ok && str != "" { - content = append(content, ContentBlock{ - Type: "text", - Text: str, - }) - } - } - - // Tool calls - if len(choice.Message.ToolCalls) > 0 { - for _, tc := range choice.Message.ToolCalls { - // 解析 arguments JSON - var input interface{} - if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil { - return nil, fmt.Errorf("解析 tool_call arguments 失败: %w", err) - } - - content = append(content, ContentBlock{ - Type: "tool_use", - ID: tc.ID, - Name: tc.Function.Name, - Input: input, - }) - } - } - } - - anthropicResp.Content = 
content - - // 转换 finish_reason - switch choice.FinishReason { - case "stop": - anthropicResp.StopReason = "end_turn" - case "tool_calls": - anthropicResp.StopReason = "tool_use" - case "length": - anthropicResp.StopReason = "max_tokens" - } - } - - return anthropicResp, nil -} - -// convertMessage 转换单条消息 -func convertMessage(msg AnthropicMessage) ([]openai.Message, error) { - var messages []openai.Message - - // 处理 content - for _, block := range msg.Content { - switch block.Type { - case "text": - // 文本内容 - messages = append(messages, openai.Message{ - Role: msg.Role, - Content: block.Text, - }) - - case "tool_result": - // 工具结果 - content := "" - if str, ok := block.Content.(string); ok { - content = str - } else { - // 如果是数组或其他类型,序列化为 JSON - bytes, err := json.Marshal(block.Content) - if err != nil { - return nil, fmt.Errorf("序列化 tool_result 内容失败: %w", err) - } - content = string(bytes) - } - - messages = append(messages, openai.Message{ - Role: "tool", - Content: content, - ToolCallID: block.ToolUseID, - }) - - case "image": - // MVP 不支持多模态 - return nil, fmt.Errorf("MVP 不支持多模态内容(图片)") - - default: - return nil, fmt.Errorf("未知的内容块类型: %s", block.Type) - } - } - - // 如果没有 content,创建空消息(不应该发生) - if len(messages) == 0 { - messages = append(messages, openai.Message{ - Role: msg.Role, - Content: "", - }) - } - - return messages, nil -} - -// convertToolChoice 转换工具选择 -func convertToolChoice(choice interface{}) (interface{}, error) { - // 如果是字符串 - if str, ok := choice.(string); ok { - // "auto" 或 "any" 都映射为 "auto" - if str == "auto" || str == "any" { - return "auto", nil - } - return nil, fmt.Errorf("无效的 tool_choice 字符串: %s", str) - } - - // 如果是对象 - if obj, ok := choice.(map[string]interface{}); ok { - choiceType, ok := obj["type"].(string) - if !ok { - return nil, fmt.Errorf("tool_choice 对象缺少 type 字段") - } - - switch choiceType { - case "auto", "any": - return "auto", nil - case "tool": - name, ok := obj["name"].(string) - if !ok { - return nil, fmt.Errorf("tool_choice 
type=tool 缺少 name 字段") - } - return map[string]interface{}{ - "type": "function", - "function": map[string]string{ - "name": name, - }, - }, nil - default: - return nil, fmt.Errorf("无效的 tool_choice type: %s", choiceType) - } - } - - return nil, fmt.Errorf("tool_choice 格式无效") -} diff --git a/backend/internal/protocol/anthropic/converter_test.go b/backend/internal/protocol/anthropic/converter_test.go deleted file mode 100644 index 683d9d6..0000000 --- a/backend/internal/protocol/anthropic/converter_test.go +++ /dev/null @@ -1,270 +0,0 @@ -package anthropic - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "nex/backend/internal/protocol/openai" -) - -func TestConvertRequest_Basic(t *testing.T) { - temp := 0.7 - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 1024, - Temperature: &temp, - Messages: []AnthropicMessage{ - { - Role: "user", - Content: []ContentBlock{ - {Type: "text", Text: "Hello"}, - }, - }, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - assert.Equal(t, "claude-3-opus", result.Model) - assert.Equal(t, 1024, *result.MaxTokens) - assert.Equal(t, &temp, result.Temperature) - require.Len(t, result.Messages, 1) - assert.Equal(t, "user", result.Messages[0].Role) - assert.Equal(t, "Hello", result.Messages[0].Content) -} - -func TestConvertRequest_WithSystem(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 100, - System: "You are a helpful assistant.", - Messages: []AnthropicMessage{ - { - Role: "user", - Content: []ContentBlock{{Type: "text", Text: "Hi"}}, - }, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - require.Len(t, result.Messages, 2) - assert.Equal(t, "system", result.Messages[0].Role) - assert.Equal(t, "You are a helpful assistant.", result.Messages[0].Content) - assert.Equal(t, "user", result.Messages[1].Role) -} - -func TestConvertRequest_DefaultMaxTokens(t *testing.T) 
{ - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 0, // 未设置 - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - assert.Equal(t, 4096, *result.MaxTokens) -} - -func TestConvertRequest_WithTools(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 100, - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - Tools: []AnthropicTool{ - { - Name: "get_weather", - Description: "Get weather info", - InputSchema: map[string]interface{}{"type": "object"}, - }, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - require.Len(t, result.Tools, 1) - assert.Equal(t, "function", result.Tools[0].Type) - assert.Equal(t, "get_weather", result.Tools[0].Function.Name) -} - -func TestConvertRequest_WithStopSequences(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 100, - StopSequences: []string{"STOP", "END"}, - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - assert.Equal(t, []string{"STOP", "END"}, result.Stop) -} - -func TestConvertRequest_ToolResult(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 100, - Messages: []AnthropicMessage{ - { - Role: "user", - Content: []ContentBlock{ - { - Type: "tool_result", - ToolUseID: "tool_123", - Content: "result data", - }, - }, - }, - }, - } - - result, err := ConvertRequest(req) - require.NoError(t, err) - require.Len(t, result.Messages, 1) - assert.Equal(t, "tool", result.Messages[0].Role) - assert.Equal(t, "tool_123", result.Messages[0].ToolCallID) - assert.Equal(t, "result data", result.Messages[0].Content) -} - -func TestConvertResponse(t *testing.T) { - resp := &openai.ChatCompletionResponse{ - ID: 
"chatcmpl-123", - Model: "gpt-4", - Choices: []openai.Choice{ - { - Index: 0, - Message: &openai.Message{Role: "assistant", Content: "Hello!"}, - FinishReason: "stop", - }, - }, - Usage: openai.Usage{PromptTokens: 10, CompletionTokens: 5}, - } - - result, err := ConvertResponse(resp) - require.NoError(t, err) - assert.Equal(t, "chatcmpl-123", result.ID) - assert.Equal(t, "message", result.Type) - assert.Equal(t, "assistant", result.Role) - assert.Equal(t, "end_turn", result.StopReason) - require.Len(t, result.Content, 1) - assert.Equal(t, "text", result.Content[0].Type) - assert.Equal(t, "Hello!", result.Content[0].Text) - assert.Equal(t, 10, result.Usage.InputTokens) - assert.Equal(t, 5, result.Usage.OutputTokens) -} - -func TestConvertResponse_ToolCalls(t *testing.T) { - args, _ := json.Marshal(map[string]interface{}{"city": "Beijing"}) - resp := &openai.ChatCompletionResponse{ - ID: "chatcmpl-456", - Model: "gpt-4", - Choices: []openai.Choice{ - { - Index: 0, - Message: &openai.Message{ - Role: "assistant", - ToolCalls: []openai.ToolCall{ - { - ID: "call_123", - Type: "function", - Function: openai.FunctionCall{ - Name: "get_weather", - Arguments: string(args), - }, - }, - }, - }, - FinishReason: "tool_calls", - }, - }, - Usage: openai.Usage{}, - } - - result, err := ConvertResponse(resp) - require.NoError(t, err) - assert.Equal(t, "tool_use", result.StopReason) - require.Len(t, result.Content, 1) - assert.Equal(t, "tool_use", result.Content[0].Type) - assert.Equal(t, "call_123", result.Content[0].ID) - assert.Equal(t, "get_weather", result.Content[0].Name) -} - -func TestConvertToolChoice_String(t *testing.T) { - tests := []struct { - name string - input interface{} - wantErr bool - check func(interface{}) - }{ - {"auto字符串", "auto", false, func(r interface{}) { assert.Equal(t, "auto", r) }}, - {"any字符串", "any", false, func(r interface{}) { assert.Equal(t, "auto", r) }}, - {"无效字符串", "invalid", true, nil}, - {"tool对象", map[string]interface{}{"type": "tool", 
"name": "my_func"}, false, - func(r interface{}) { - m := r.(map[string]interface{}) - assert.Equal(t, "function", m["type"]) - }}, - {"缺少name的tool对象", map[string]interface{}{"type": "tool"}, true, nil}, - {"缺少type的对象", map[string]interface{}{"name": "func"}, true, nil}, - {"无效类型", 42, true, nil}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := convertToolChoice(tt.input) - if tt.wantErr { - assert.Error(t, err) - } else { - require.NoError(t, err) - tt.check(result) - } - }) - } -} - -func TestValidateRequest(t *testing.T) { - t.Run("有效请求", func(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 100, - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - } - errs := ValidateRequest(req) - assert.Nil(t, errs) - }) - - t.Run("缺少模型", func(t *testing.T) { - req := &MessagesRequest{ - MaxTokens: 100, - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - } - errs := ValidateRequest(req) - assert.NotNil(t, errs) - assert.Contains(t, errs["model"], "不能为空") - }) - - t.Run("MaxTokens为0", func(t *testing.T) { - req := &MessagesRequest{ - Model: "claude-3-opus", - MaxTokens: 0, - Messages: []AnthropicMessage{ - {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, - }, - } - errs := ValidateRequest(req) - assert.NotNil(t, errs) - }) -} diff --git a/backend/internal/protocol/anthropic/stream_converter.go b/backend/internal/protocol/anthropic/stream_converter.go deleted file mode 100644 index f6e6f5b..0000000 --- a/backend/internal/protocol/anthropic/stream_converter.go +++ /dev/null @@ -1,164 +0,0 @@ -package anthropic - -import ( - "encoding/json" - "fmt" - - "nex/backend/internal/protocol/openai" -) - -// StreamConverter 流式转换器 -type StreamConverter struct { - messageID string - model string - index int // 当前 content block index - toolCallArgs map[int]string // 缓存每个 tool_call 的 
arguments - sentStart bool // 是否已发送 message_start - sentBlockStart map[int]bool // 每个 index 是否已发送 content_block_start -} - -// NewStreamConverter 创建流式转换器 -func NewStreamConverter(messageID, model string) *StreamConverter { - return &StreamConverter{ - messageID: messageID, - model: model, - index: 0, - toolCallArgs: make(map[int]string), - sentStart: false, - sentBlockStart: make(map[int]bool), - } -} - -// ConvertChunk 转换 OpenAI 流块为 Anthropic 事件 -func (c *StreamConverter) ConvertChunk(chunk *openai.StreamChunk) ([]StreamEvent, error) { - var events []StreamEvent - - // 发送 message_start(仅一次) - if !c.sentStart { - events = append(events, StreamEvent{ - Type: "message_start", - Message: &MessagesResponse{ - ID: c.messageID, - Type: "message", - Role: "assistant", - Model: c.model, - Content: []ContentBlock{}, - Usage: Usage{ - InputTokens: 0, - OutputTokens: 0, - }, - }, - }) - c.sentStart = true - } - - // 处理每个 choice - for _, choice := range chunk.Choices { - // 处理 content delta - if choice.Delta.Content != "" { - // 发送 content_block_start(如果还没发送) - if !c.sentBlockStart[c.index] { - events = append(events, StreamEvent{ - Type: "content_block_start", - Index: c.index, - ContentBlock: &ContentBlock{ - Type: "text", - }, - }) - c.sentBlockStart[c.index] = true - } - - // 发送 text delta - events = append(events, StreamEvent{ - Type: "content_block_delta", - Index: c.index, - Delta: &Delta{ - Type: "text_delta", - Text: choice.Delta.Content, - }, - }) - } - - // 处理 tool_calls delta - if len(choice.Delta.ToolCalls) > 0 { - for _, tc := range choice.Delta.ToolCalls { - // 确定 tool_call index - toolIndex := c.index + len(c.toolCallArgs) - - // 发送 content_block_start(如果还没发送) - if !c.sentBlockStart[toolIndex] { - events = append(events, StreamEvent{ - Type: "content_block_start", - Index: toolIndex, - ContentBlock: &ContentBlock{ - Type: "tool_use", - ID: tc.ID, - Name: tc.Function.Name, - }, - }) - c.sentBlockStart[toolIndex] = true - c.toolCallArgs[toolIndex] = "" - } - - // 
缓存 arguments - c.toolCallArgs[toolIndex] += tc.Function.Arguments - - // 发送 input delta - events = append(events, StreamEvent{ - Type: "content_block_delta", - Index: toolIndex, - Delta: &Delta{ - Type: "input_json_delta", - Input: tc.Function.Arguments, - }, - }) - } - } - - // 处理 finish_reason - if choice.FinishReason != "" { - // 发送 content_block_stop - for idx := range c.sentBlockStart { - events = append(events, StreamEvent{ - Type: "content_block_stop", - Index: idx, - }) - } - - // 转换 stop_reason - stopReason := "" - switch choice.FinishReason { - case "stop": - stopReason = "end_turn" - case "tool_calls": - stopReason = "tool_use" - case "length": - stopReason = "max_tokens" - } - - // 发送 message_delta - events = append(events, StreamEvent{ - Type: "message_delta", - Delta: &Delta{ - StopReason: stopReason, - }, - }) - - // 发送 message_stop - events = append(events, StreamEvent{ - Type: "message_stop", - }) - } - } - - return events, nil -} - -// SerializeEvent 序列化事件为 SSE 格式 -func SerializeEvent(event StreamEvent) (string, error) { - bytes, err := json.Marshal(event) - if err != nil { - return "", err - } - return fmt.Sprintf("event: %s\ndata: %s\n\n", event.Type, string(bytes)), nil -} diff --git a/backend/internal/protocol/anthropic/stream_converter_test.go b/backend/internal/protocol/anthropic/stream_converter_test.go deleted file mode 100644 index 1608894..0000000 --- a/backend/internal/protocol/anthropic/stream_converter_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package anthropic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "nex/backend/internal/protocol/openai" -) - -func TestStreamConverter_MessageStart(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - chunk := &openai.StreamChunk{ - ID: "chatcmpl-123", - Choices: []openai.StreamChoice{{Index: 0, Delta: openai.Delta{}}}, - } - - events, err := converter.ConvertChunk(chunk) - require.NoError(t, err) - 
require.NotEmpty(t, events) - - // 第一个事件应该是 message_start - assert.Equal(t, "message_start", events[0].Type) - require.NotNil(t, events[0].Message) - assert.Equal(t, "msg_123", events[0].Message.ID) - assert.Equal(t, "message", events[0].Message.Type) - assert.Equal(t, "assistant", events[0].Message.Role) - assert.Equal(t, "claude-3-opus", events[0].Message.Model) -} - -func TestStreamConverter_TextDelta(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - // 先发送一个空块以触发 message_start - chunk1 := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - {Delta: openai.Delta{Content: "Hello"}}, - }, - } - events1, err := converter.ConvertChunk(chunk1) - require.NoError(t, err) - // 应有 message_start + content_block_start + text delta - assert.GreaterOrEqual(t, len(events1), 3) - - // 第二个文本块不应再发送 message_start 和 content_block_start - chunk2 := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - {Delta: openai.Delta{Content: " world"}}, - }, - } - events2, err := converter.ConvertChunk(chunk2) - require.NoError(t, err) - // 只有 text delta - assert.Len(t, events2, 1) - assert.Equal(t, "content_block_delta", events2[0].Type) - assert.Equal(t, "text_delta", events2[0].Delta.Type) - assert.Equal(t, " world", events2[0].Delta.Text) -} - -func TestStreamConverter_FinishReason(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - chunk := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - {Delta: openai.Delta{Content: "Hello"}, FinishReason: "stop"}, - }, - } - events, err := converter.ConvertChunk(chunk) - require.NoError(t, err) - - // 查找 message_delta 事件 - var messageDelta *StreamEvent - for _, e := range events { - if e.Type == "message_delta" { - messageDelta = &e - break - } - } - require.NotNil(t, messageDelta) - assert.Equal(t, "end_turn", messageDelta.Delta.StopReason) - - // 查找 message_stop 事件 - var messageStop *StreamEvent - for _, e := range events { - if e.Type == "message_stop" { - messageStop = &e 
- break - } - } - assert.NotNil(t, messageStop) -} - -func TestStreamConverter_FinishReasonToolCalls(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - chunk := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - {Delta: openai.Delta{}, FinishReason: "tool_calls"}, - }, - } - events, err := converter.ConvertChunk(chunk) - require.NoError(t, err) - - var messageDelta *StreamEvent - for _, e := range events { - if e.Type == "message_delta" { - messageDelta = &e - break - } - } - require.NotNil(t, messageDelta) - assert.Equal(t, "tool_use", messageDelta.Delta.StopReason) -} - -func TestStreamConverter_FinishReasonLength(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - chunk := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - {Delta: openai.Delta{}, FinishReason: "length"}, - }, - } - events, err := converter.ConvertChunk(chunk) - require.NoError(t, err) - - var messageDelta *StreamEvent - for _, e := range events { - if e.Type == "message_delta" { - messageDelta = &e - break - } - } - require.NotNil(t, messageDelta) - assert.Equal(t, "max_tokens", messageDelta.Delta.StopReason) -} - -func TestStreamConverter_ToolCalls(t *testing.T) { - converter := NewStreamConverter("msg_123", "claude-3-opus") - - chunk := &openai.StreamChunk{ - Choices: []openai.StreamChoice{ - { - Delta: openai.Delta{ - ToolCalls: []openai.ToolCall{ - { - ID: "call_123", - Type: "function", - Function: openai.FunctionCall{ - Name: "get_weather", - Arguments: `{"city": "Beijing"}`, - }, - }, - }, - }, - }, - }, - } - - events, err := converter.ConvertChunk(chunk) - require.NoError(t, err) - - // 应包含 content_block_start (tool_use) + content_block_delta (input_json_delta) - hasBlockStart := false - hasInputDelta := false - for _, e := range events { - if e.Type == "content_block_start" && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" { - hasBlockStart = true - assert.Equal(t, "call_123", e.ContentBlock.ID) - 
assert.Equal(t, "get_weather", e.ContentBlock.Name) - } - if e.Type == "content_block_delta" && e.Delta != nil && e.Delta.Type == "input_json_delta" { - hasInputDelta = true - assert.Equal(t, `{"city": "Beijing"}`, e.Delta.Input) - } - } - assert.True(t, hasBlockStart, "应有 tool_use content_block_start") - assert.True(t, hasInputDelta, "应有 input_json_delta") -} - -func TestSerializeEvent(t *testing.T) { - event := StreamEvent{ - Type: "message_start", - Message: &MessagesResponse{ - ID: "msg_123", - Type: "message", - Role: "assistant", - }, - } - - result, err := SerializeEvent(event) - require.NoError(t, err) - assert.Contains(t, result, "event: message_start") - assert.Contains(t, result, "data: ") - assert.Contains(t, result, "msg_123") -} - -func TestSerializeEvent_InvalidJSON(t *testing.T) { - event := StreamEvent{ - Type: "test", - } - // 这个应该能正常序列化 - result, err := SerializeEvent(event) - require.NoError(t, err) - assert.Contains(t, result, "event: test") -} - -func TestContentBlock_ParseInputJSON(t *testing.T) { - t.Run("字符串输入", func(t *testing.T) { - cb := &ContentBlock{Input: `{"key": "value"}`} - result, err := cb.ParseInputJSON() - require.NoError(t, err) - assert.Equal(t, "value", result["key"]) - }) - - t.Run("对象输入", func(t *testing.T) { - cb := &ContentBlock{Input: map[string]interface{}{"key": "value"}} - result, err := cb.ParseInputJSON() - require.NoError(t, err) - assert.Equal(t, "value", result["key"]) - }) - - t.Run("无效类型", func(t *testing.T) { - cb := &ContentBlock{Input: 42} - _, err := cb.ParseInputJSON() - assert.Error(t, err) - }) -} diff --git a/backend/internal/protocol/anthropic/types.go b/backend/internal/protocol/anthropic/types.go deleted file mode 100644 index 54f37bf..0000000 --- a/backend/internal/protocol/anthropic/types.go +++ /dev/null @@ -1,149 +0,0 @@ -package anthropic - -import ( - "encoding/json" - "fmt" - - "github.com/go-playground/validator/v10" - - pkgValidator "nex/backend/pkg/validator" -) - -// MessagesRequest 
Anthropic Messages API 请求结构 -type MessagesRequest struct { - Model string `json:"model" validate:"required"` - Messages []AnthropicMessage `json:"messages" validate:"required,min=1"` - System string `json:"system,omitempty"` - MaxTokens int `json:"max_tokens" validate:"required,min=1"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - TopK *int `json:"top_k,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools []AnthropicTool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` // 可以是字符串或对象 - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - -// AnthropicMessage Anthropic 消息结构 -type AnthropicMessage struct { - Role string `json:"role"` - Content []ContentBlock `json:"content"` -} - -// ContentBlock 内容块 -type ContentBlock struct { - Type string `json:"type"` // "text", "image", "tool_use", "tool_result" - Text string `json:"text,omitempty"` - Input interface{} `json:"input,omitempty"` // 用于 tool_use - - // tool_use 字段 - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - - // tool_result 字段 - ToolUseID string `json:"tool_use_id,omitempty"` - Content interface{} `json:"content,omitempty"` // 可以是字符串或数组 - - // 多模态字段(MVP 不支持) - Source interface{} `json:"source,omitempty"` // 用于 image -} - -// AnthropicTool Anthropic 工具定义 -type AnthropicTool struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - InputSchema map[string]interface{} `json:"input_schema"` -} - -// ToolChoice 工具选择 -type ToolChoice struct { - Type string `json:"type"` // "auto", "any", "tool" - Name string `json:"name,omitempty"` // 当 type="tool" 时使用 -} - -// MessagesResponse Anthropic Messages API 响应结构 -type MessagesResponse struct { - ID string `json:"id"` - Type string `json:"type"` // "message" - Role string `json:"role"` // "assistant" - Content []ContentBlock `json:"content"` - Model string 
`json:"model"` - StopReason string `json:"stop_reason,omitempty"` // "end_turn", "max_tokens", "stop_sequence", "tool_use" - StopSequence string `json:"stop_sequence,omitempty"` - Usage Usage `json:"usage"` -} - -// Usage 使用统计 -type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -// StreamEvent 流式事件 -type StreamEvent struct { - Type string `json:"type"` - Message *MessagesResponse `json:"message,omitempty"` // 用于 message_start - Index int `json:"index,omitempty"` // 用于 content_block_* 事件 - ContentBlock *ContentBlock `json:"content_block,omitempty"` // 用于 content_block_start - Delta *Delta `json:"delta,omitempty"` // 用于 content_block_delta -} - -// Delta 增量内容 -type Delta struct { - Type string `json:"type,omitempty"` // "text_delta", "input_json_delta" - Text string `json:"text,omitempty"` - Input string `json:"input,omitempty"` // 用于 tool_use 的部分 JSON - StopReason string `json:"stop_reason,omitempty"` // 用于 message_delta - Usage *Usage `json:"usage,omitempty"` // 用于 message_delta -} - -// ErrorResponse Anthropic 错误响应 -type ErrorResponse struct { - Type string `json:"type"` // "error" - Error ErrorDetail `json:"error"` -} - -// ErrorDetail 错误详情 -type ErrorDetail struct { - Type string `json:"type"` // "invalid_request_error", "authentication_error", etc. 
- Message string `json:"message"` -} - -// ParseInputJSON 解析 tool_use 的 input(从 JSON 字符串转为 map) -func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) { - if str, ok := cb.Input.(string); ok { - var result map[string]interface{} - err := json.Unmarshal([]byte(str), &result) - return result, err - } - // 如果已经是对象,直接返回 - if obj, ok := cb.Input.(map[string]interface{}); ok { - return obj, nil - } - return nil, fmt.Errorf("invalid input type: expected string or map") -} - -// ValidateRequest 验证 MessagesRequest -func ValidateRequest(req *MessagesRequest) map[string]string { - errs := pkgValidator.Validate(req) - if errs == nil { - return nil - } - - validationErrors := make(map[string]string) - for _, err := range errs.(validator.ValidationErrors) { - field := err.Field() - switch field { - case "Model": - validationErrors["model"] = "模型名称不能为空" - case "Messages": - validationErrors["messages"] = "消息列表不能为空" - case "MaxTokens": - validationErrors["max_tokens"] = "max_tokens 不能为空且必须大于 0" - default: - validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag()) - } - } - return validationErrors -} diff --git a/backend/internal/protocol/openai/adapter.go b/backend/internal/protocol/openai/adapter.go deleted file mode 100644 index 708b6c7..0000000 --- a/backend/internal/protocol/openai/adapter.go +++ /dev/null @@ -1,82 +0,0 @@ -package openai - -import ( - "bytes" - "encoding/json" - "io" - "net/http" -) - -// Adapter OpenAI 协议适配器(透传) -type Adapter struct{} - -// NewAdapter 创建 OpenAI 适配器 -func NewAdapter() *Adapter { - return &Adapter{} -} - -// PrepareRequest 准备发送给供应商的请求(透传) -func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL string) (*http.Request, error) { - // 序列化请求体 - body, err := json.Marshal(req) - if err != nil { - return nil, err - } - - // 创建 HTTP 请求 - // baseURL 已包含版本路径(如 /v1 或 /v4),只需添加端点路径 - httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body)) - if err != nil { - 
return nil, err - } - - // 设置请求头 - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+apiKey) - - return httpReq, nil -} - -// ParseResponse 解析供应商响应(透传) -func (a *Adapter) ParseResponse(resp *http.Response) (*ChatCompletionResponse, error) { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var result ChatCompletionResponse - err = json.Unmarshal(body, &result) - if err != nil { - return nil, err - } - - return &result, nil -} - -// ParseErrorResponse 解析错误响应 -func (a *Adapter) ParseErrorResponse(resp *http.Response) (*ErrorResponse, error) { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var result ErrorResponse - err = json.Unmarshal(body, &result) - if err != nil { - return nil, err - } - - return &result, nil -} - -// ParseStreamChunk 解析流式响应块 -func (a *Adapter) ParseStreamChunk(data []byte) (*StreamChunk, error) { - var chunk StreamChunk - err := json.Unmarshal(data, &chunk) - if err != nil { - return nil, err - } - return &chunk, nil -} diff --git a/backend/internal/protocol/openai/adapter_test.go b/backend/internal/protocol/openai/adapter_test.go deleted file mode 100644 index 0dce920..0000000 --- a/backend/internal/protocol/openai/adapter_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package openai - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAdapter_PrepareRequest(t *testing.T) { - adapter := NewAdapter() - req := &ChatCompletionRequest{ - Model: "gpt-4", - Messages: []Message{ - {Role: "user", Content: "Hello"}, - }, - } - - httpReq, err := adapter.PrepareRequest(req, "test-api-key", "https://api.openai.com/v1") - require.NoError(t, err) - require.NotNil(t, httpReq) - - assert.Equal(t, "POST", httpReq.Method) - assert.Equal(t, "https://api.openai.com/v1/chat/completions", 
httpReq.URL.String()) - assert.Equal(t, "application/json", httpReq.Header.Get("Content-Type")) - assert.Equal(t, "Bearer test-api-key", httpReq.Header.Get("Authorization")) - - // 验证请求体 - var body ChatCompletionRequest - err = json.NewDecoder(httpReq.Body).Decode(&body) - require.NoError(t, err) - assert.Equal(t, "gpt-4", body.Model) -} - -func TestAdapter_ParseResponse(t *testing.T) { - adapter := NewAdapter() - resp := &ChatCompletionResponse{ - ID: "chatcmpl-123", - Object: "chat.completion", - Created: 1234567890, - Model: "gpt-4", - Choices: []Choice{ - { - Index: 0, - Message: &Message{Role: "assistant", Content: "Hello!"}, - }, - }, - Usage: Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}, - } - - body, err := json.Marshal(resp) - require.NoError(t, err) - - httpResp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(body)), - } - - result, err := adapter.ParseResponse(httpResp) - require.NoError(t, err) - assert.Equal(t, "chatcmpl-123", result.ID) - assert.Equal(t, "gpt-4", result.Model) - require.Len(t, result.Choices, 1) - assert.Equal(t, "Hello!", result.Choices[0].Message.Content) -} - -func TestAdapter_ParseErrorResponse(t *testing.T) { - adapter := NewAdapter() - errResp := &ErrorResponse{ - Error: ErrorDetail{ - Message: "Invalid API key", - Type: "invalid_request_error", - Code: "invalid_api_key", - }, - } - - body, err := json.Marshal(errResp) - require.NoError(t, err) - - httpResp := &http.Response{ - StatusCode: 401, - Body: io.NopCloser(bytes.NewReader(body)), - } - - result, err := adapter.ParseErrorResponse(httpResp) - require.NoError(t, err) - assert.Equal(t, "Invalid API key", result.Error.Message) - assert.Equal(t, "invalid_request_error", result.Error.Type) -} - -func TestAdapter_ParseStreamChunk(t *testing.T) { - adapter := NewAdapter() - chunk := &StreamChunk{ - ID: "chatcmpl-123", - Object: "chat.completion.chunk", - Created: 1234567890, - Model: "gpt-4", - Choices: []StreamChoice{ - { - Index: 0, 
- Delta: Delta{Content: "Hello"}, - }, - }, - } - - data, err := json.Marshal(chunk) - require.NoError(t, err) - - result, err := adapter.ParseStreamChunk(data) - require.NoError(t, err) - assert.Equal(t, "chatcmpl-123", result.ID) - require.Len(t, result.Choices, 1) - assert.Equal(t, "Hello", result.Choices[0].Delta.Content) -} - -func TestParseToolCallArguments(t *testing.T) { - tests := []struct { - name string - input string - wantErr bool - }{ - {"有效JSON", `{"key": "value"}`, false}, - {"无效JSON", `not json`, true}, - {"空JSON", `{}`, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tc := &ToolCall{ - Function: FunctionCall{Arguments: tt.input}, - } - args, err := tc.ParseToolCallArguments() - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.NotNil(t, args) - } - }) - } -} - -func TestSerializeToolCallArguments(t *testing.T) { - args := map[string]interface{}{"key": "value"} - result, err := SerializeToolCallArguments(args) - require.NoError(t, err) - assert.JSONEq(t, `{"key": "value"}`, result) -} - -func TestValidateRequest(t *testing.T) { - t.Run("有效请求", func(t *testing.T) { - req := &ChatCompletionRequest{ - Model: "gpt-4", - Messages: []Message{{Role: "user", Content: "hello"}}, - } - errs := ValidateRequest(req) - assert.Nil(t, errs) - }) - - t.Run("缺少模型", func(t *testing.T) { - req := &ChatCompletionRequest{ - Messages: []Message{{Role: "user", Content: "hello"}}, - } - errs := ValidateRequest(req) - assert.NotNil(t, errs) - assert.Contains(t, errs["model"], "不能为空") - }) - - t.Run("缺少消息", func(t *testing.T) { - req := &ChatCompletionRequest{ - Model: "gpt-4", - } - errs := ValidateRequest(req) - assert.NotNil(t, errs) - assert.Contains(t, errs["messages"], "不能为空") - }) - - t.Run("空消息列表", func(t *testing.T) { - req := &ChatCompletionRequest{ - Model: "gpt-4", - Messages: []Message{}, - } - errs := ValidateRequest(req) - assert.NotNil(t, errs) - }) -} diff --git 
a/backend/internal/protocol/openai/types.go b/backend/internal/protocol/openai/types.go deleted file mode 100644 index b181dc5..0000000 --- a/backend/internal/protocol/openai/types.go +++ /dev/null @@ -1,160 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - - "github.com/go-playground/validator/v10" - - pkgValidator "nex/backend/pkg/validator" -) - -// ChatCompletionRequest OpenAI Chat Completions API 请求结构 -type ChatCompletionRequest struct { - Model string `json:"model" validate:"required"` - Messages []Message `json:"messages" validate:"required,min=1"` - Temperature *float64 `json:"temperature,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - Stop interface{} `json:"stop,omitempty"` // 可以是字符串或字符串数组 - N *int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` // 可以是字符串或对象 - User string `json:"user,omitempty"` -} - -// Message OpenAI 消息结构 -type Message struct { - Role string `json:"role"` - Content interface{} `json:"content"` // 可以是字符串或数组(多模态,MVP不支持) - Name string `json:"name,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` // 用于 role="tool" 的消息 -} - -// Tool OpenAI 工具定义 -type Tool struct { - Type string `json:"type"` // 目前只有 "function" - Function FunctionDefinition `json:"function"` -} - -// FunctionDefinition 函数定义 -type FunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Parameters map[string]interface{} `json:"parameters,omitempty"` -} - -// ToolCall 工具调用 -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` // "function" - Function FunctionCall `json:"function"` -} - -// FunctionCall 函数调用 -type FunctionCall struct { - 
Name string `json:"name"` - Arguments string `json:"arguments"` // JSON 字符串 -} - -// ChatCompletionResponse OpenAI Chat Completions API 响应结构 -type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` -} - -// Choice 响应选项 -type Choice struct { - Index int `json:"index"` - Message *Message `json:"message,omitempty"` - Delta *Delta `json:"delta,omitempty"` // 用于流式响应 - FinishReason string `json:"finish_reason"` -} - -// Delta 流式响应增量 -type Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` -} - -// Usage Token 使用统计 -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// StreamChunk 流式响应块 -type StreamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []StreamChoice `json:"choices"` -} - -// StreamChoice 流式响应选项 -type StreamChoice struct { - Index int `json:"index"` - Delta Delta `json:"delta"` - FinishReason string `json:"finish_reason,omitempty"` -} - -// ErrorResponse OpenAI 错误响应 -type ErrorResponse struct { - Error ErrorDetail `json:"error"` -} - -// ErrorDetail 错误详情 -type ErrorDetail struct { - Message string `json:"message"` - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` -} - -// ParseToolCallArguments 解析 tool_call 的 arguments(从 JSON 字符串转为 map) -func (tc *ToolCall) ParseToolCallArguments() (map[string]interface{}, error) { - var args map[string]interface{} - err := json.Unmarshal([]byte(tc.Function.Arguments), &args) - return args, err -} - -// SerializeToolCallArguments 序列化 tool_call 的 arguments(从 map 转为 JSON 字符串) -func SerializeToolCallArguments(args map[string]interface{}) (string, 
error) { - bytes, err := json.Marshal(args) - if err != nil { - return "", err - } - return string(bytes), nil -} - -// ValidateRequest 验证 ChatCompletionRequest -func ValidateRequest(req *ChatCompletionRequest) map[string]string { - errs := pkgValidator.Validate(req) - if errs == nil { - return nil - } - - validationErrors := make(map[string]string) - for _, err := range errs.(validator.ValidationErrors) { - field := err.Field() - switch field { - case "Model": - validationErrors["model"] = "模型名称不能为空" - case "Messages": - validationErrors["messages"] = "消息列表不能为空" - default: - validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag()) - } - } - return validationErrors -} diff --git a/backend/internal/provider/client.go b/backend/internal/provider/client.go index abf1edf..fafcda5 100644 --- a/backend/internal/provider/client.go +++ b/backend/internal/provider/client.go @@ -11,15 +11,15 @@ import ( "go.uber.org/zap" - "nex/backend/internal/protocol/openai" + "nex/backend/internal/conversion" ) // StreamConfig 流式处理配置 type StreamConfig struct { - InitialBufferSize int // 初始缓冲区大小(字节),默认 4096 - MaxBufferSize int // 最大缓冲区大小(字节),默认 65536 - Timeout time.Duration // 流超时时间,默认 5 分钟 - ChannelBufferSize int // 事件通道缓冲区大小,默认 100 + InitialBufferSize int + MaxBufferSize int + Timeout time.Duration + ChannelBufferSize int } // DefaultStreamConfig 返回默认流式处理配置 @@ -32,14 +32,6 @@ func DefaultStreamConfig() StreamConfig { } } -// Client OpenAI 兼容供应商客户端 -type Client struct { - httpClient *http.Client - adapter *openai.Adapter - logger *zap.Logger - streamCfg StreamConfig -} - // StreamEvent 流事件 type StreamEvent struct { Data []byte @@ -47,10 +39,17 @@ type StreamEvent struct { Done bool } +// Client 协议无关的供应商客户端 +type Client struct { + httpClient *http.Client + logger *zap.Logger + streamCfg StreamConfig +} + // ProviderClient 供应商客户端接口 type ProviderClient interface { - SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) 
(*openai.ChatCompletionResponse, error) - SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) + Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) + SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) } // NewClient 创建供应商客户端 @@ -59,97 +58,98 @@ func NewClient() *Client { httpClient: &http.Client{ Timeout: 30 * time.Second, }, - adapter: openai.NewAdapter(), logger: zap.L(), streamCfg: DefaultStreamConfig(), } } -// SendRequest 发送非流式请求 -func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) { - // 准备请求 - httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL) +// Send 发送非流式请求 +func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { + var bodyReader io.Reader + if len(spec.Body) > 0 { + bodyReader = bytes.NewReader(spec.Body) + } + + httpReq, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, bodyReader) if err != nil { - return nil, fmt.Errorf("准备请求失败: %w", err) + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + for k, v := range spec.Headers { + httpReq.Header.Set(k, v) } c.logger.Debug("发送请求", - zap.String("url", httpReq.URL.String()), - zap.String("method", httpReq.Method), + zap.String("url", spec.URL), + zap.String("method", spec.Method), ) - // 设置上下文 - httpReq = httpReq.WithContext(ctx) - - // 发送请求 resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("发送请求失败: %w", err) } + defer resp.Body.Close() - // 检查状态码 - if resp.StatusCode != http.StatusOK { - // 解析错误响应 - errorResp, parseErr := c.adapter.ParseErrorResponse(resp) - if parseErr != nil { - return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) - } - return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message) - } - - // 解析响应 - result, err := 
c.adapter.ParseResponse(resp) + respBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) + return nil, fmt.Errorf("读取响应失败: %w", err) } - return result, nil + respHeaders := make(map[string]string) + for k, vs := range resp.Header { + if len(vs) > 0 { + respHeaders[k] = vs[0] + } + } + + return &conversion.HTTPResponseSpec{ + StatusCode: resp.StatusCode, + Headers: respHeaders, + Body: respBody, + }, nil } -// SendStreamRequest 发送流式请求 -func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) { - // 确保请求设置为流式 - req.Stream = true - - // 准备请求 - httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL) - if err != nil { - return nil, fmt.Errorf("准备请求失败: %w", err) +// SendStream 发送流式请求 +func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) { + var bodyReader io.Reader + if len(spec.Body) > 0 { + bodyReader = bytes.NewReader(spec.Body) } - // 设置带超时的上下文 streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout) - _ = cancel // cancel 在流读取结束后由 ctx 传播处理 - httpReq = httpReq.WithContext(streamCtx) + httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader) + if err != nil { + cancel() + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + for k, v := range spec.Headers { + httpReq.Header.Set(k, v) + } - // 发送请求 resp, err := c.httpClient.Do(httpReq) if err != nil { cancel() return nil, fmt.Errorf("发送请求失败: %w", err) } - // 检查状态码 if resp.StatusCode != http.StatusOK { defer resp.Body.Close() cancel() - errorResp, parseErr := c.adapter.ParseErrorResponse(resp) - if parseErr != nil { - return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) + errBody, _ := io.ReadAll(resp.Body) + if len(errBody) > 0 { + return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) } - return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message) + return 
nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) } - // 创建事件通道 eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize) - - // 启动 goroutine 读取流 go c.readStream(streamCtx, cancel, resp.Body, eventChan) return eventChan, nil } -// readStream 读取 SSE 流(支持动态缓冲区、超时控制和改进的错误处理) +// readStream 读取 SSE 流 func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) { defer close(eventChan) defer body.Close() @@ -175,10 +175,8 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body n, err := body.Read(buf) if err != nil { if err == io.EOF { - // 流正常结束 return } - // 区分网络错误和其他错误 if isNetworkError(err) { c.logger.Error("流网络错误", zap.String("error", err.Error())) eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} @@ -191,7 +189,6 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body dataBuf = append(dataBuf, buf[:n]...) - // 动态调整缓冲区大小:如果数据量大,增大缓冲区 if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize { newSize := bufSize * 2 if newSize > c.streamCfg.MaxBufferSize { @@ -201,34 +198,21 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body bufSize = newSize } - // 处理完整的 SSE 事件 for { - // 查找事件边界(双换行) idx := bytes.Index(dataBuf, []byte("\n\n")) if idx == -1 { break } - // 提取事件 - event := dataBuf[:idx] + rawEvent := dataBuf[:idx] dataBuf = dataBuf[idx+2:] - // 解析 data 行 - lines := strings.Split(string(event), "\n") - for _, line := range lines { - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - - // 检查是否是结束标记 - if data == "[DONE]" { - eventChan <- StreamEvent{Done: true} - return - } - - // 发送数据 - eventChan <- StreamEvent{Data: []byte(data)} - } + if bytes.Contains(rawEvent, []byte("data: [DONE]")) { + eventChan <- StreamEvent{Done: true} + return } + + eventChan <- StreamEvent{Data: rawEvent} } } } @@ -245,4 +229,3 @@ func isNetworkError(err error) bool { 
strings.Contains(errStr, "timeout") || strings.Contains(errStr, "EOF") } - diff --git a/backend/internal/provider/client_test.go b/backend/internal/provider/client_test.go index 66d56bb..5661b92 100644 --- a/backend/internal/provider/client_test.go +++ b/backend/internal/provider/client_test.go @@ -2,7 +2,6 @@ package provider import ( "context" - "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -11,14 +10,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "nex/backend/internal/protocol/openai" + "nex/backend/internal/conversion" ) func TestNewClient(t *testing.T) { client := NewClient() require.NotNil(t, client) assert.NotNil(t, client.httpClient) - assert.NotNil(t, client.adapter) assert.Equal(t, 4096, client.streamCfg.InitialBufferSize) assert.Equal(t, 65536, client.streamCfg.MaxBufferSize) assert.Equal(t, 100, client.streamCfg.ChannelBufferSize) @@ -31,67 +29,66 @@ func TestDefaultStreamConfig(t *testing.T) { assert.Equal(t, 100, cfg.ChannelBufferSize) } -func TestClient_SendRequest_Success(t *testing.T) { +func TestClient_Send_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "POST", r.Method) assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) - - resp := openai.ChatCompletionResponse{ - ID: "chatcmpl-123", - Choices: []openai.Choice{ - {Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}}, - }, - } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test","model":"gpt-4"}`)) })) defer server.Close() client := NewClient() - req := &openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + spec := conversion.HTTPRequestSpec{ + URL: server.URL + "/v1/chat/completions", + Method: "POST", + 
Headers: map[string]string{ + "Authorization": "Bearer test-key", + "Content-Type": "application/json", + }, + Body: []byte(`{"model":"gpt-4","messages":[]}`), } - result, err := client.SendRequest(context.Background(), req, "test-key", server.URL) + result, err := client.Send(context.Background(), spec) require.NoError(t, err) - assert.Equal(t, "chatcmpl-123", result.ID) + assert.Equal(t, 200, result.StatusCode) + assert.Contains(t, string(result.Body), "test") } -func TestClient_SendRequest_ErrorResponse(t *testing.T) { +func TestClient_Send_ErrorResponse(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(openai.ErrorResponse{ - Error: openai.ErrorDetail{Message: "Invalid API key"}, - }) + w.Write([]byte(`{"error":{"message":"Invalid API key"}}`)) })) defer server.Close() client := NewClient() - req := &openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + spec := conversion.HTTPRequestSpec{ + URL: server.URL + "/v1/chat/completions", + Method: "POST", + Headers: map[string]string{"Authorization": "Bearer bad-key"}, + Body: []byte(`{}`), } - _, err := client.SendRequest(context.Background(), req, "bad-key", server.URL) - assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid API key") + result, err := client.Send(context.Background(), spec) + require.NoError(t, err) + assert.Equal(t, 401, result.StatusCode) } -func TestClient_SendRequest_ConnectionError(t *testing.T) { +func TestClient_Send_ConnectionError(t *testing.T) { client := NewClient() - req := &openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + spec := conversion.HTTPRequestSpec{ + URL: "http://localhost:1/v1/chat/completions", + Method: "POST", } - _, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1") + _, err := 
client.Send(context.Background(), spec) assert.Error(t, err) } -func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) { - // 使用一个慢服务器确保客户端有时间读取 +func TestClient_SendStream_CreatesChannel(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) @@ -99,35 +96,36 @@ func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) { defer server.Close() client := NewClient() - req := &openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + spec := conversion.HTTPRequestSpec{ + URL: server.URL + "/v1/chat/completions", + Method: "POST", + Headers: map[string]string{"Authorization": "Bearer test-key"}, + Body: []byte(`{}`), } - eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL) + eventChan, err := client.SendStream(context.Background(), spec) require.NoError(t, err) require.NotNil(t, eventChan) - // 读取直到 channel 关闭(服务器关闭后应产生 EOF) for range eventChan { - // 消费所有事件 } - // channel 应已关闭(不阻塞即通过) } -func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) { +func TestClient_SendStream_ErrorResponse(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() client := NewClient() - req := &openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + spec := conversion.HTTPRequestSpec{ + URL: server.URL + "/v1/chat/completions", + Method: "POST", + Headers: map[string]string{"Authorization": "Bearer key"}, + Body: []byte(`{}`), } - _, err := client.SendStreamRequest(context.Background(), req, "key", server.URL) + _, err := client.SendStream(context.Background(), spec) assert.Error(t, err) } @@ -145,7 +143,7 @@ func TestIsNetworkError(t *testing.T) { {"", false}, } for _, tt 
:= range tests { - err := fmt.Errorf("%s", tt.input) //nolint:govet + err := fmt.Errorf("%s", tt.input) assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input) } } diff --git a/backend/internal/repository/provider_repo_impl.go b/backend/internal/repository/provider_repo_impl.go index 45b7501..6ea917b 100644 --- a/backend/internal/repository/provider_repo_impl.go +++ b/backend/internal/repository/provider_repo_impl.go @@ -77,6 +77,7 @@ func toDomainProvider(p *config.Provider) domain.Provider { Name: p.Name, APIKey: p.APIKey, BaseURL: p.BaseURL, + Protocol: p.Protocol, Enabled: p.Enabled, CreatedAt: p.CreatedAt, UpdatedAt: p.UpdatedAt, @@ -85,10 +86,11 @@ func toDomainProvider(p *config.Provider) domain.Provider { func toConfigProvider(p *domain.Provider) config.Provider { return config.Provider{ - ID: p.ID, - Name: p.Name, - APIKey: p.APIKey, - BaseURL: p.BaseURL, - Enabled: p.Enabled, + ID: p.ID, + Name: p.Name, + APIKey: p.APIKey, + BaseURL: p.BaseURL, + Protocol: p.Protocol, + Enabled: p.Enabled, } } diff --git a/backend/migrations/001_initial_schema.sql b/backend/migrations/20260401000001_initial_schema.sql similarity index 100% rename from backend/migrations/001_initial_schema.sql rename to backend/migrations/20260401000001_initial_schema.sql diff --git a/backend/migrations/002_add_indexes.sql b/backend/migrations/20260401000002_add_indexes.sql similarity index 100% rename from backend/migrations/002_add_indexes.sql rename to backend/migrations/20260401000002_add_indexes.sql diff --git a/backend/migrations/20260419000001_add_provider_protocol.sql b/backend/migrations/20260419000001_add_provider_protocol.sql new file mode 100644 index 0000000..6ed08b7 --- /dev/null +++ b/backend/migrations/20260419000001_add_provider_protocol.sql @@ -0,0 +1,6 @@ +-- +goose Up +ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'; + +-- +goose Down +-- SQLite 不支持 DROP COLUMN(3.35.0 之前),但 goose 的 Down 通常不需要 +CREATE TABLE providers_backup AS SELECT 
id, name, api_key, base_url, enabled, created_at, updated_at FROM providers; DROP TABLE providers; ALTER TABLE providers_backup RENAME TO providers; diff --git a/backend/tests/integration/conversion_test.go b/backend/tests/integration/conversion_test.go new file mode 100644 index 0000000..c866505 --- /dev/null +++ b/backend/tests/integration/conversion_test.go @@ -0,0 +1,571 @@ +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/conversion" + "nex/backend/internal/conversion/anthropic" + openaiConv "nex/backend/internal/conversion/openai" + "nex/backend/internal/handler" + "nex/backend/internal/handler/middleware" + "nex/backend/internal/provider" + "nex/backend/internal/repository" + "nex/backend/internal/service" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// setupConversionTest 创建包含 ConversionEngine 的完整测试环境 +func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) { + t.Helper() + + // 创建 mock 上游服务器 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 默认返回成功,由各测试 case 覆盖 + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"error":"not mocked"}`)) + })) + + dir := t.TempDir() + db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) + require.NoError(t, err) + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + require.NoError(t, err) + t.Cleanup(func() { + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + upstream.Close() + }) + + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + statsRepo := repository.NewStatsRepository(db) + + providerService := service.NewProviderService(providerRepo) + modelService := service.NewModelService(modelRepo,
providerRepo) + routingService := service.NewRoutingService(modelRepo, providerRepo) + statsService := service.NewStatsService(statsRepo) + + // 创建 ConversionEngine + registry := conversion.NewMemoryRegistry() + require.NoError(t, registry.Register(openaiConv.NewAdapter())) + require.NoError(t, registry.Register(anthropic.NewAdapter())) + engine := conversion.NewConversionEngine(registry) + + providerClient := provider.NewClient() + proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) + providerHandler := handler.NewProviderHandler(providerService) + modelHandler := handler.NewModelHandler(modelService) + statsHandler := handler.NewStatsHandler(statsService) + + _ = modelService + + r := gin.New() + r.Use(middleware.CORS()) + + // 代理路由 + r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy) + + // 管理路由 + providers := r.Group("/api/providers") + { + providers.GET("", providerHandler.ListProviders) + providers.POST("", providerHandler.CreateProvider) + providers.GET("/:id", providerHandler.GetProvider) + providers.PUT("/:id", providerHandler.UpdateProvider) + providers.DELETE("/:id", providerHandler.DeleteProvider) + } + models := r.Group("/api/models") + { + models.GET("", modelHandler.ListModels) + models.POST("", modelHandler.CreateModel) + models.GET("/:id", modelHandler.GetModel) + models.PUT("/:id", modelHandler.UpdateModel) + models.DELETE("/:id", modelHandler.DeleteModel) + } + _ = statsHandler + + return r, db, upstream +} + +// createProviderAndModel 辅助:创建供应商和模型 +func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, modelName string, upstreamURL string) { + t.Helper() + + providerBody, _ := json.Marshal(map[string]string{ + "id": providerID, + "name": providerID, + "api_key": "test-key", + "base_url": upstreamURL, + "protocol": protocol, + }) + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + 
req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + require.Equal(t, 201, w.Code) + + modelBody, _ := json.Marshal(map[string]string{ + "id": modelName, + "provider_id": providerID, + "model_name": modelName, + }) + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + require.Equal(t, 201, w.Code) +} + +// ============ 跨协议非流式转换测试 ============ + +func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + // 配置上游返回 Anthropic 格式响应 + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求被转换为 Anthropic 格式 + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + + assert.Equal(t, "/v1/messages", r.URL.Path) + assert.Contains(t, r.Header.Get("Content-Type"), "application/json") + + // 返回 Anthropic 响应 + resp := map[string]any{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": "claude-3-opus", + "content": []map[string]any{ + {"type": "text", "text": "Hello from Anthropic!"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{ + "input_tokens": 10, + "output_tokens": 20, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + + createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + + // 使用 OpenAI 格式发送请求 + openaiReq := map[string]any{ + "model": "claude-3-opus", + "messages": []map[string]any{ + {"role": "user", "content": "Hello"}, + }, + "stream": false, + } + body, _ := json.Marshal(openaiReq) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), 
&resp) + assert.Equal(t, "chat.completion", resp["object"]) + + choices := resp["choices"].([]any) + require.Len(t, choices, 1) + choice := choices[0].(map[string]any) + msg := choice["message"].(map[string]any) + assert.Contains(t, msg["content"], "Hello from Anthropic!") +} + +func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + + assert.Equal(t, "/v1/chat/completions", r.URL.Path) + assert.Contains(t, r.Header.Get("Authorization"), "Bearer test-key") + + resp := map[string]any{ + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "gpt-4", + "created": time.Now().Unix(), + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{"role": "assistant", "content": "Hello from OpenAI!"}, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + + createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + + anthropicReq := map[string]any{ + "model": "gpt-4", + "max_tokens": 1024, + "messages": []map[string]any{ + {"role": "user", "content": "Hello"}, + }, + "stream": false, + } + body, _ := json.Marshal(anthropicReq) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, "message", resp["type"]) + + content := resp["content"].([]any) + require.Len(t, content, 1) + block := content[0].(map[string]any) + assert.Contains(t, block["text"], "Hello from OpenAI!") +} 
+ +// ============ 同协议透传测试 ============ + +func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/chat/completions", r.URL.Path) + + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + assert.Equal(t, "gpt-4", req["model"]) + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`)) + }) + + createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + + reqBody := map[string]any{ + "model": "gpt-4", + "messages": []map[string]any{{"role": "user", "content": "test"}}, + } + body, _ := json.Marshal(reqBody) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "passthrough") +} + +func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/messages", r.URL.Path) + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`)) + }) + + createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + + reqBody := map[string]any{ + "model": "claude-3-opus", + "max_tokens": 1024, + "messages": 
[]map[string]any{{"role": "user", "content": "test"}}, + } + body, _ := json.Marshal(reqBody) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "passthrough") +} + +// ============ 流式转换测试 ============ + +func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + events := []string{ + "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"model\":\"claude-3-opus\",\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\n", + "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n", + "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n", + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", + } + for _, e := range events { + w.Write([]byte(e)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + }) + + createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + + openaiReq := map[string]any{ + "model": "claude-3-opus", + "messages": []map[string]any{{"role": "user", "content": "Hello"}}, + "stream": true, + } + body, _ := json.Marshal(openaiReq) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/openai/v1/chat/completions", 
bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + ct := w.Header().Get("Content-Type") + assert.Contains(t, ct, "text/event-stream") +} + +func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) { + r, _, upstream := setupConversionTest(t) + + upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + events := []string{ + fmt.Sprintf("data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"}}]}\n\n"), + "data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hey\"}}]}\n\n", + "data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"}}]}\n\n", + "data: {\"id\":\"chatcmpl-s\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n", + "data: [DONE]\n\n", + } + for _, e := range events { + w.Write([]byte(e)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + }) + + createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + + anthropicReq := map[string]any{ + "model": "gpt-4", + "max_tokens": 1024, + "messages": []map[string]any{{"role": "user", "content": "Hello"}}, + "stream": true, + } + body, _ := json.Marshal(anthropicReq) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/anthropic/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + ct := w.Header().Get("Content-Type") + assert.Contains(t, ct, "text/event-stream") +} + +// ============ Models 接口测试 ============ + +func 
TestConversion_Models_CrossProtocol(t *testing.T) { + // 测试 Models 接口跨协议转换的编解码逻辑 + // 由于 GET /models 无 body 无法路由,此处测试 adapter 级别的编解码 + + registry := conversion.NewMemoryRegistry() + require.NoError(t, registry.Register(openaiConv.NewAdapter())) + require.NoError(t, registry.Register(anthropic.NewAdapter())) + + openaiAdapter, _ := registry.Get("openai") + anthropicAdapter, _ := registry.Get("anthropic") + + // 模拟 OpenAI 格式的 models 响应 + openaiModelsBody := []byte(`{"object":"list","data":[{"id":"gpt-4","object":"model","created":1700000000,"owned_by":"openai"},{"id":"gpt-3.5-turbo","object":"model","created":1700000001,"owned_by":"openai"}]}`) + + // OpenAI decode → Canonical → Anthropic encode + modelList, err := openaiAdapter.DecodeModelsResponse(openaiModelsBody) + require.NoError(t, err) + assert.Len(t, modelList.Models, 2) + assert.Equal(t, "gpt-4", modelList.Models[0].ID) + + // 编码为 Anthropic 格式 + anthropicBody, err := anthropicAdapter.EncodeModelsResponse(modelList) + require.NoError(t, err) + + var anthropicResp map[string]any + json.Unmarshal(anthropicBody, &anthropicResp) + data := anthropicResp["data"].([]any) + assert.Len(t, data, 2) + + first := data[0].(map[string]any) + assert.Equal(t, "gpt-4", first["id"]) + assert.Equal(t, "model", first["type"]) + + // 反向测试:Anthropic decode → Canonical → OpenAI encode + anthropicModelsBody := []byte(`{"data":[{"id":"claude-3-opus","type":"model","display_name":"Claude 3 Opus","created_at":"2025-01-01T00:00:00Z"}],"has_more":false}`) + modelList2, err := anthropicAdapter.DecodeModelsResponse(anthropicModelsBody) + require.NoError(t, err) + assert.Len(t, modelList2.Models, 1) + assert.Equal(t, "Claude 3 Opus", modelList2.Models[0].Name) + + openaiBody, err := openaiAdapter.EncodeModelsResponse(modelList2) + require.NoError(t, err) + + var openaiResp map[string]any + require.True(t, json.Valid(openaiBody)) + json.Unmarshal(openaiBody, &openaiResp) + oaiData := openaiResp["data"].([]any) + assert.Len(t, oaiData, 1) + firstOai
:= oaiData[0].(map[string]any) + assert.Equal(t, "claude-3-opus", firstOai["id"]) +} + +// ============ 错误响应测试 ============ + +func TestConversion_ErrorResponse_Format(t *testing.T) { + r, _, _ := setupConversionTest(t) + + // 请求不存在的模型 + reqBody := map[string]any{ + "model": "nonexistent", + "messages": []map[string]any{{"role": "user", "content": "test"}}, + } + body, _ := json.Marshal(reqBody) + + // OpenAI 协议格式 + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.True(t, w.Code >= 400) +} + +// ============ 旧路由返回 404 ============ + +func TestConversion_OldRoutes_Return404(t *testing.T) { + r, _, _ := setupConversionTest(t) + + // 旧 OpenAI 路由 + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader(`{"model":"test"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + // Gin 路由不匹配返回 404 + assert.Equal(t, 404, w.Code) + + // 旧 Anthropic 路由 + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/v1/messages", strings.NewReader(`{"model":"test"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 404, w.Code) +} + +// ============ Provider Protocol 字段测试 ============ + +func TestConversion_ProviderWithProtocol(t *testing.T) { + r, _, _ := setupConversionTest(t) + + // 创建带 protocol 字段的 provider + providerBody := map[string]any{ + "id": "test-protocol", + "name": "Test Protocol", + "api_key": "sk-test", + "base_url": "https://test.com", + "protocol": "anthropic", + } + body, _ := json.Marshal(providerBody) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + require.Equal(t, 201, w.Code) + + var created map[string]any + json.Unmarshal(w.Body.Bytes(), 
&created) + // API Key 被掩码 + assert.Contains(t, created["api_key"], "***") + + // 获取时应包含 protocol + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/providers/test-protocol", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + + var fetched map[string]any + json.Unmarshal(w.Body.Bytes(), &fetched) + assert.Equal(t, "anthropic", fetched["protocol"]) +} + +func TestConversion_ProviderDefaultProtocol(t *testing.T) { + r, _, _ := setupConversionTest(t) + + // 不指定 protocol,默认应为 openai + providerBody := map[string]any{ + "id": "default-proto", + "name": "Default", + "api_key": "sk-test", + "base_url": "https://test.com", + } + body, _ := json.Marshal(providerBody) + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + require.Equal(t, 201, w.Code) + + var created map[string]any + json.Unmarshal(w.Body.Bytes(), &created) + assert.Equal(t, "openai", created["protocol"]) +} + +// Suppress unused imports +var _ = fmt.Sprintf +var _ = strings.Contains +var _ = time.Second diff --git a/openspec/changes/refactor-conversion-engine/design.md b/openspec/changes/refactor-conversion-engine/design.md index 23449fe..8085e1e 100644 --- a/openspec/changes/refactor-conversion-engine/design.md +++ b/openspec/changes/refactor-conversion-engine/design.md @@ -180,14 +180,14 @@ type ProviderClient interface { - 路由时需要知道 providerProtocol 以选择正确的 Adapter - 默认值 `'openai'` 确保现有数据兼容 -### D7: 删除旧 `internal/protocol/` 包,在 `internal/conversion/` 中重建 +### D7: 删除旧 `internal/protocol/` 包,在 `internal/conversion/` 中全新实现 -**选择**:直接删除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,在 `internal/conversion/` 下从零构建新架构 +**选择**:直接删除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,在 `internal/conversion/` 下对照设计文档全新编写所有代码 **理由**: -- 旧代码的设计模式(OpenAI 类型为枢纽)与新架构根本不同 +- 旧代码的设计模式(OpenAI 类型为枢纽)与新架构根本不同,无法复用 - 保留旧代码容易导致混用两种模式,引入隐蔽 bug -- 
旧代码中的类型定义可以迁移(copy-paste),但组织方式需重建 +- 旧代码中的类型定义不迁移,直接根据设计文档重新定义,确保与新架构一致 ### D8: 目标包结构 @@ -206,7 +206,7 @@ internal/conversion/ engine.go # ConversionEngine 门面 + HTTPRequestSpec/HTTPResponseSpec openai/ - types.go # OpenAI 线路格式类型(从旧 protocol/openai/types.go 迁移并补全) + types.go # OpenAI 线路格式类型(对照 conversion_openai.md 全新定义) adapter.go # ProtocolAdapter 实现(detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError) decoder.go # decodeRequest/decodeResponse/扩展层 decode 方法 encoder.go # encodeRequest/encodeResponse/扩展层 encode 方法 @@ -214,7 +214,7 @@ internal/conversion/ stream_encoder.go # OpenAIStreamEncoder(缓冲策略) anthropic/ - types.go # Anthropic 线路格式类型(从旧 protocol/anthropic/types.go 迁移并补全) + types.go # Anthropic 线路格式类型(对照 conversion_anthropic.md 全新定义) adapter.go # ProtocolAdapter 实现(detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError) decoder.go # decodeRequest/decodeResponse/扩展层 decode 方法 encoder.go # encodeRequest/encodeResponse/扩展层 encode 方法 @@ -260,14 +260,14 @@ internal/conversion/ ### 步骤 1. **创建 `internal/conversion/` 包**:实现 Layer 1-3(Canonical Model、接口定义、Engine),不改动现有代码 -2. **实现 OpenAI Adapter 和 Anthropic Adapter**:Layer 4-5,在 conversion 包内自包含 +2. **全新实现 OpenAI Adapter 和 Anthropic Adapter**:Layer 4-5,对照设计文档在 conversion 包内全新编写,不沿用旧 protocol 包代码 3. **编写全面测试**:覆盖编解码、流式转换、错误处理、同协议透传 4. **改造 `domain.Provider`**:新增 `Protocol` 字段 5. **创建数据库迁移**:`ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'` 6. **改造 `ProviderClient`**:简化为接受 `HTTPRequestSpec` 的 HTTP 发送器 7. **创建 `ProxyHandler`**:统一代理入口,集成 ConversionEngine 8. **更新 `cmd/server/main.go`**:注册 Adapter、创建 Engine、配置新路由 -9. **删除旧 `internal/protocol/` 包**:确认新架构完全替代后删除 +9. **删除旧 `internal/protocol/` 包**:直接删除,不迁移代码,确认新架构完全替代 10. 
**更新 README.md**:项目结构、API 接口、路由说明 ### 兼容策略 @@ -279,7 +279,7 @@ internal/conversion/ ### 回滚策略 - Git 分支隔离:在新分支开发,合并前充分测试 -- 旧 `internal/protocol/` 包在确认新架构稳定后再删除 +- 旧 `internal/protocol/` 包在删除前确认新架构所有测试通过,删除后不可恢复旧代码(从 git 历史仍可找回) - 数据库迁移向下兼容(仅 ADD COLUMN) ## Open Questions diff --git a/openspec/changes/refactor-conversion-engine/proposal.md b/openspec/changes/refactor-conversion-engine/proposal.md index 950a2cb..d769625 100644 --- a/openspec/changes/refactor-conversion-engine/proposal.md +++ b/openspec/changes/refactor-conversion-engine/proposal.md @@ -7,14 +7,14 @@ - **引入 Canonical Model**:定义协议无关的 `CanonicalRequest`、`CanonicalResponse`、`CanonicalStreamEvent` 等规范模型,作为所有协议间转换的统一枢纽 - **引入 ConversionEngine**:无状态的转换引擎门面,协调 Adapter 注册、接口识别、透传判断、请求/响应转换、流式转换 - **引入 ProtocolAdapter 接口**:统一适配器契约,每种协议实现完整的编解码(Chat 请求/响应、流式、扩展层接口、错误编码) -- **实现 OpenAI Adapter**:对照 `docs/conversion_openai.md` 实现 OpenAI 协议的完整 Adapter(含状态机流式解码器/编码器) -- **实现 Anthropic Adapter**:对照 `docs/conversion_anthropic.md` 实现 Anthropic 协议的完整 Adapter(含命名事件流式解码器/编码器) +- **实现 OpenAI Adapter**:对照 `docs/conversion_openai.md` 全新实现 OpenAI 协议的完整 Adapter(含状态机流式解码器/编码器),不沿用旧 `internal/protocol/openai/` 代码 +- **实现 Anthropic Adapter**:对照 `docs/conversion_anthropic.md` 全新实现 Anthropic 协议的完整 Adapter(含命名事件流式解码器/编码器),不沿用旧 `internal/protocol/anthropic/` 代码 - **统一代理 Handler**:合并 `OpenAIHandler` 和 `AnthropicHandler` 为统一的 `ProxyHandler`,支持 `/{protocol}/v1/...` URL 前缀路由 - **同协议透传**:client == provider 时跳过 Canonical 转换,仅重建 Header 后原样转发 - **接口分层**:核心层(Chat)走 Canonical 深度转换,扩展层(Models/Embeddings/Rerank)走轻量映射,未知接口走透传 - **ProviderClient 简化**:移除 OpenAI Adapter 硬编码,变为协议无关的 HTTP 发送器 - **Provider 新增 Protocol 字段**:**BREAKING** — Provider 模型新增 `protocol` 字段标识上游协议类型 -- **删除旧 protocol 包**:移除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,全部逻辑迁入 `internal/conversion/` +- **删除旧 protocol 包**:移除 `internal/protocol/openai/` 和 `internal/protocol/anthropic/`,在 `internal/conversion/` 中全新实现 - **URL 路由变更**:**BREAKING** — 代理端点从 
`/v1/chat/completions` + `/v1/messages` 变更为 `/{protocol}/v1/...`,不保留旧路由 ## Capabilities @@ -37,7 +37,7 @@ ## Impact -- **代码结构**:新增 `internal/conversion/` 包(约 20+ 文件),删除 `internal/protocol/` 包,改造 `internal/handler/` 和 `internal/provider/` +- **代码结构**:新增 `internal/conversion/` 包(约 20+ 文件,全新编写),删除 `internal/protocol/` 包(不迁移,直接删除后重写),改造 `internal/handler/` 和 `internal/provider/` - **API 兼容性**:**BREAKING** — 代理端点 URL 变更(`/v1/chat/completions` → `/openai/v1/chat/completions`,`/v1/messages` → `/anthropic/v1/messages`),不保留旧路由 - **数据库**:Provider 表新增 `protocol` 列,需数据库迁移 - **依赖**:无新增外部依赖,复用现有 Go 标准库和已引入的包 diff --git a/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-anthropic/spec.md b/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-anthropic/spec.md index 254843f..a7d1316 100644 --- a/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-anthropic/spec.md +++ b/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-anthropic/spec.md @@ -2,7 +2,7 @@ ### Requirement: 实现 Anthropic ProtocolAdapter -系统 SHALL 实现 Anthropic 协议的完整 ProtocolAdapter,对照 `docs/conversion_anthropic.md`。 +系统 SHALL 全新实现 Anthropic 协议的完整 ProtocolAdapter,对照 `docs/conversion_anthropic.md`。不沿用旧 `internal/protocol/anthropic/` 代码。 - `protocolName()` SHALL 返回 `"anthropic"` - `supportsPassthrough()` SHALL 返回 true diff --git a/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-openai/spec.md b/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-openai/spec.md index c20e8fd..3a4fac8 100644 --- a/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-openai/spec.md +++ b/openspec/changes/refactor-conversion-engine/specs/protocol-adapter-openai/spec.md @@ -2,7 +2,7 @@ ### Requirement: 实现 OpenAI ProtocolAdapter -系统 SHALL 实现 OpenAI 协议的完整 ProtocolAdapter,对照 `docs/conversion_openai.md`。 +系统 SHALL 全新实现 OpenAI 协议的完整 ProtocolAdapter,对照 `docs/conversion_openai.md`。不沿用旧 `internal/protocol/openai/` 代码。 - `protocolName()` SHALL 
返回 `"openai"` - `supportsPassthrough()` SHALL 返回 true diff --git a/openspec/changes/refactor-conversion-engine/tasks.md b/openspec/changes/refactor-conversion-engine/tasks.md index 30ad8bd..60fdf12 100644 --- a/openspec/changes/refactor-conversion-engine/tasks.md +++ b/openspec/changes/refactor-conversion-engine/tasks.md @@ -1,49 +1,49 @@ ## 1. 基础类型层 — Canonical Model 和核心类型定义 -- [ ] 1.1 创建 `internal/conversion/errors.go`:定义 ConversionError 结构体(Code, Message, ClientProtocol, ProviderProtocol, InterfaceType, Details, Cause)和 ErrorCode 枚举(INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED),实现 error 接口 -- [ ] 1.2 创建 `internal/conversion/interface.go`:定义 InterfaceType 枚举(CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK) -- [ ] 1.3 创建 `internal/conversion/provider.go`:定义 TargetProvider 结构体(BaseURL, APIKey, ModelName, AdapterConfig map[string]any);编写测试 -- [ ] 1.4 创建 `internal/conversion/canonical/types.go`:定义 CanonicalRequest(model, system, messages, tools, tool_choice, parameters, thinking, stream, user_id, output_format, parallel_tool_use)、CanonicalMessage(role 枚举: system/user/assistant/tool, content []ContentBlock)、ContentBlock(使用 type 字段的 discriminated union:text/tool_use/tool_result/thinking,ToolInput 使用 json.RawMessage)、CanonicalTool(name, description, input_schema)、ToolChoice 联合体(auto/none/any/tool+name)、RequestParameters(max_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty, stop_sequences)、ThinkingConfig(type: enabled/disabled/adaptive, budget_tokens, effort)、OutputFormat(json_object/json_schema+schema/text)、CanonicalResponse(id, model, content, stop_reason 枚举, usage)、CanonicalUsage(input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, reasoning_tokens)、SystemBlock(text);编写构造和序列化测试 -- [ ] 1.5 创建 `internal/conversion/canonical/stream.go`:定义 
CanonicalStreamEvent 联合体(message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop, error, ping)及各事件的具体结构(MessageStartEvent 含 message{id,model,usage}、ContentBlockStartEvent 含 index 和 content_block、ContentBlockDeltaEvent 含 index 和 delta、ContentBlockStopEvent 含 index、MessageDeltaEvent 含 delta{stop_reason} 和 usage、MessageStopEvent、ErrorEvent、PingEvent),delta 联合体(text_delta, input_json_delta, thinking_delta),content_block 联合体(text, tool_use, thinking);编写测试 -- [ ] 1.6 创建 `internal/conversion/canonical/extended.go`:定义扩展层 Canonical Models(CanonicalModelList, CanonicalModel, CanonicalModelInfo, CanonicalEmbeddingRequest, CanonicalEmbeddingResponse, CanonicalRerankRequest, CanonicalRerankResponse);编写测试 +- [x] 1.1 创建 `internal/conversion/errors.go`:定义 ConversionError 结构体(Code, Message, ClientProtocol, ProviderProtocol, InterfaceType, Details, Cause)和 ErrorCode 枚举(INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED),实现 error 接口 +- [x] 1.2 创建 `internal/conversion/interface.go`:定义 InterfaceType 枚举(CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK) +- [x] 1.3 创建 `internal/conversion/provider.go`:定义 TargetProvider 结构体(BaseURL, APIKey, ModelName, AdapterConfig map[string]any);编写测试 +- [x] 1.4 创建 `internal/conversion/canonical/types.go`:定义 CanonicalRequest(model, system, messages, tools, tool_choice, parameters, thinking, stream, user_id, output_format, parallel_tool_use)、CanonicalMessage(role 枚举: system/user/assistant/tool, content []ContentBlock)、ContentBlock(使用 type 字段的 discriminated union:text/tool_use/tool_result/thinking,ToolInput 使用 json.RawMessage)、CanonicalTool(name, description, input_schema)、ToolChoice 联合体(auto/none/any/tool+name)、RequestParameters(max_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty, 
stop_sequences)、ThinkingConfig(type: enabled/disabled/adaptive, budget_tokens, effort)、OutputFormat(json_object/json_schema+schema/text)、CanonicalResponse(id, model, content, stop_reason 枚举, usage)、CanonicalUsage(input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, reasoning_tokens)、SystemBlock(text);编写构造和序列化测试 +- [x] 1.5 创建 `internal/conversion/canonical/stream.go`:定义 CanonicalStreamEvent 联合体(message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop, error, ping)及各事件的具体结构(MessageStartEvent 含 message{id,model,usage}、ContentBlockStartEvent 含 index 和 content_block、ContentBlockDeltaEvent 含 index 和 delta、ContentBlockStopEvent 含 index、MessageDeltaEvent 含 delta{stop_reason} 和 usage、MessageStopEvent、ErrorEvent、PingEvent),delta 联合体(text_delta, input_json_delta, thinking_delta),content_block 联合体(text, tool_use, thinking);编写测试 +- [x] 1.6 创建 `internal/conversion/canonical/extended.go`:定义扩展层 Canonical Models(CanonicalModelList, CanonicalModel, CanonicalModelInfo, CanonicalEmbeddingRequest, CanonicalEmbeddingResponse, CanonicalRerankRequest, CanonicalRerankResponse);编写测试 ## 2. 
接口定义层 — Adapter、Stream、Middleware 接口 -- [ ] 2.1 创建 `internal/conversion/adapter.go`:定义 ProtocolAdapter 接口(protocolName, protocolVersion, supportsPassthrough, detectInterfaceType, buildUrl, buildHeaders, supportsInterface, decodeRequest, encodeRequest, decodeResponse, encodeResponse, createStreamDecoder, createStreamEncoder, encodeError, 扩展层编解码方法:decodeModelsResponse/encodeModelsResponse/decodeModelInfoResponse/encodeModelInfoResponse/decodeEmbeddingRequest/encodeEmbeddingRequest/decodeEmbeddingResponse/encodeEmbeddingResponse/decodeRerankRequest/encodeRerankRequest/decodeRerankResponse/encodeRerankResponse),定义 AdapterRegistry 接口(register, get, listProtocols)和 memoryRegistry 实现(sync.RWMutex 保护的 map);编写 Registry 注册/查询/重复注册测试 -- [ ] 2.2 创建 `internal/conversion/stream.go`:定义 StreamDecoder 接口(processChunk(rawChunk []byte) []CanonicalStreamEvent, flush() []CanonicalStreamEvent)、StreamEncoder 接口(encodeEvent(event CanonicalStreamEvent) [][]byte, flush() [][]byte)、StreamConverter 接口(processChunk(rawChunk []byte) [][]byte, flush() [][]byte)、PassthroughStreamConverter 实现(直接传递原始字节)、CanonicalStreamConverter 实现(组合 StreamDecoder + MiddlewareChain + StreamEncoder,processChunk 内部调用 decoder → middleware → encoder 管道);编写 PassthroughStreamConverter 测试 -- [ ] 2.3 创建 `internal/conversion/middleware.go`:定义 ConversionMiddleware 接口(intercept(canonical, clientProtocol, providerProtocol, context) (CanonicalRequest, error) 和可选的 interceptStreamEvent(event, clientProtocol, providerProtocol, context) (CanonicalStreamEvent, error))、ConversionContext 结构体(conversionId, interfaceType, timestamp, metadata)、MiddlewareChain 结构体(按注册顺序链式执行,任一返回错误则中断后续);编写链式执行和中断测试 +- [x] 2.1 创建 `internal/conversion/adapter.go`:定义 ProtocolAdapter 接口(protocolName, protocolVersion, supportsPassthrough, detectInterfaceType, buildUrl, buildHeaders, supportsInterface, decodeRequest, encodeRequest, decodeResponse, encodeResponse, createStreamDecoder, createStreamEncoder, encodeError, 
扩展层编解码方法:decodeModelsResponse/encodeModelsResponse/decodeModelInfoResponse/encodeModelInfoResponse/decodeEmbeddingRequest/encodeEmbeddingRequest/decodeEmbeddingResponse/encodeEmbeddingResponse/decodeRerankRequest/encodeRerankRequest/decodeRerankResponse/encodeRerankResponse),定义 AdapterRegistry 接口(register, get, listProtocols)和 memoryRegistry 实现(sync.RWMutex 保护的 map);编写 Registry 注册/查询/重复注册测试 +- [x] 2.2 创建 `internal/conversion/stream.go`:定义 StreamDecoder 接口(processChunk(rawChunk []byte) []CanonicalStreamEvent, flush() []CanonicalStreamEvent)、StreamEncoder 接口(encodeEvent(event CanonicalStreamEvent) [][]byte, flush() [][]byte)、StreamConverter 接口(processChunk(rawChunk []byte) [][]byte, flush() [][]byte)、PassthroughStreamConverter 实现(直接传递原始字节)、CanonicalStreamConverter 实现(组合 StreamDecoder + MiddlewareChain + StreamEncoder,processChunk 内部调用 decoder → middleware → encoder 管道);编写 PassthroughStreamConverter 测试 +- [x] 2.3 创建 `internal/conversion/middleware.go`:定义 ConversionMiddleware 接口(intercept(canonical, clientProtocol, providerProtocol, context) (CanonicalRequest, error) 和可选的 interceptStreamEvent(event, clientProtocol, providerProtocol, context) (CanonicalStreamEvent, error))、ConversionContext 结构体(conversionId, interfaceType, timestamp, metadata)、MiddlewareChain 结构体(按注册顺序链式执行,任一返回错误则中断后续);编写链式执行和中断测试 ## 3. 
引擎层 — ConversionEngine 门面 -- [ ] 3.1 创建 `internal/conversion/engine.go`:定义 HTTPRequestSpec(URL, Method string, Headers map[string]string, Body []byte)、HTTPResponseSpec(StatusCode int, Headers map[string]string, Body []byte)、ConversionEngine struct(registry, middlewareChain);实现 registerAdapter、use、isPassthrough、convertHttpRequest(接口识别 → 透传判断 → clientAdapter.decode → middleware → providerAdapter.encode → providerAdapter.buildUrl + buildHeaders)、convertHttpResponse(透传判断 → providerAdapter.decodeResponse → clientAdapter.encodeResponse)、createStreamConverter(透传 → PassthroughStreamConverter,否则 → CanonicalStreamConverter)、内部 convertBody 分发(CHAT 走深度转换,扩展层走轻量映射,默认透传);编写集成测试:使用 mock adapter 测试跨协议转换、同协议透传、未知接口透传 +- [x] 3.1 创建 `internal/conversion/engine.go`:定义 HTTPRequestSpec(URL, Method string, Headers map[string]string, Body []byte)、HTTPResponseSpec(StatusCode int, Headers map[string]string, Body []byte)、ConversionEngine struct(registry, middlewareChain);实现 registerAdapter、use、isPassthrough、convertHttpRequest(接口识别 → 透传判断 → clientAdapter.decode → middleware → providerAdapter.encode → providerAdapter.buildUrl + buildHeaders)、convertHttpResponse(透传判断 → providerAdapter.decodeResponse → clientAdapter.encodeResponse)、createStreamConverter(透传 → PassthroughStreamConverter,否则 → CanonicalStreamConverter)、内部 convertBody 分发(CHAT 走深度转换,扩展层走轻量映射,默认透传);编写集成测试:使用 mock adapter 测试跨协议转换、同协议透传、未知接口透传 ## 4. 
OpenAI Adapter 实现 -- [ ] 4.1 创建 `internal/conversion/openai/types.go`:从旧 `internal/protocol/openai/types.go` 迁移 OpenAI 线路格式类型,补全缺失字段(developer role, custom tools, reasoning_effort, reasoning_content, max_completion_tokens, parallel_tool_calls, response_format 的 json_schema 类型, stream_options, 废弃的 functions/function_call);编写序列化测试 -- [ ] 4.2 创建 `internal/conversion/openai/decoder.go`:实现 decodeRequest(对照 conversion_openai.md §4.1:decodeSystemPrompt 提取 system+developer 消息、decodeMessage 含 tool_calls/refusal/reasoning_content 解码、tool 消息 tool_call_id→tool_use_id、decodeTools 含 function+custom 类型、decodeToolChoice 含 required→any/allowed_tools 降级、decodeParameters 含 max_completion_tokens 优先、decodeOutputFormat、decodeThinking 含 reasoning_effort→ThinkingConfig、废弃字段 functions→tools 兼容)、decodeResponse(§5.2:content/refusal/reasoning_content/tool_calls 解码、finish_reason 映射表、usage 映射含 cached_tokens/reasoning_tokens)、扩展层 decode(decodeModelsResponse、decodeEmbeddingRequest/Response、decodeRerankRequest/Response);编写完整测试覆盖每类消息和字段映射 -- [ ] 4.3 创建 `internal/conversion/openai/encoder.go`:实现 encodeRequest(对照 conversion_openai.md §4.2:provider.model_name 覆盖、system 注入到 messages[0]、encodeMessage 含 tool_calls 编码到 message 顶层、角色交替合并、encodeTools 含 function 包装、encodeToolChoice 含 any→required、encodeParameters 含 max_completion_tokens、encodeOutputFormat、encodeThinking 含 disabled→"none")、encodeResponse(§5.3:text→content、tool_use→tool_calls、thinking→reasoning_content、finish_reason 反向映射、usage 编码含 prompt_tokens_details)、扩展层 encode(encodeModelsResponse、encodeEmbeddingRequest/Response、encodeRerankRequest/Response);编写完整测试 -- [ ] 4.4 创建 `internal/conversion/openai/adapter.go`:实现 OpenAI ProtocolAdapter(protocolName→"openai"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/chat/completions→CHAT、/v1/models→MODELS 等、buildHeaders 含 Authorization+Content-Type+OpenAI-Organization、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO/EMBEDDINGS/RERANK 返回 true、encodeError 含 ErrorCode→OpenAI 错误类型映射),组合 
decoder 和 encoder 方法;编写测试覆盖所有路径模式和边界情况 -- [ ] 4.5 创建 `internal/conversion/openai/stream_decoder.go`:实现 OpenAIStreamDecoder(对照 conversion_openai.md §6.2-§6.3:processChunk 解析 SSE data 行,维护状态机 messageStarted/openBlocks/toolCallIdMap/toolCallNameMap/toolCallArguments/textBlockStarted/thinkingBlockStarted/utf8Remainder/accumulatedUsage,首个 chunk→MessageStartEvent,delta.content→text block 生命周期,delta.tool_calls→tool_use block 生命周期含索引映射和参数累积,delta.reasoning_content→thinking block(非标准),delta.refusal→text block,finish_reason→关闭所有 open blocks + MessageDeltaEvent + MessageStopEvent,usage chunk→MessageDeltaEvent,[DONE]→flush 关闭);编写测试覆盖每种 delta 类型和边界情况(空 chunk、多 tool_calls、UTF-8 截断) -- [ ] 4.6 创建 `internal/conversion/openai/stream_encoder.go`:实现 OpenAIStreamEncoder(对照 conversion_openai.md §6.4:encodeEvent,ContentBlockStart 缓冲策略等待首次 ContentBlockDelta 合并输出,tool_use id/name 在首次 delta 时合并编码,text_delta 直接输出 data: {choices:[{delta:{content}}]},input_json_delta 含 tool_calls 数组编码,thinking_delta 含 reasoning_content 字段,MessageStartEvent→{choices:[{delta:{role:"assistant"}}]},MessageDeltaEvent→{choices:[{delta:{},finish_reason}]},MessageStopEvent→[DONE],PingEvent/ErrorEvent 丢弃,flush 输出缓冲区);编写测试 +- [x] 4.1 创建 `internal/conversion/openai/types.go`:对照 `docs/conversion_openai.md` 全新定义 OpenAI 线路格式类型(不沿用旧 `internal/protocol/openai/types.go`),包含完整字段(developer role, custom tools, reasoning_effort, reasoning_content, max_completion_tokens, parallel_tool_calls, response_format 的 json_schema 类型, stream_options, 废弃的 functions/function_call);编写序列化测试 +- [x] 4.2 创建 `internal/conversion/openai/decoder.go`:实现 decodeRequest(对照 conversion_openai.md §4.1:decodeSystemPrompt 提取 system+developer 消息、decodeMessage 含 tool_calls/refusal/reasoning_content 解码、tool 消息 tool_call_id→tool_use_id、decodeTools 含 function+custom 类型、decodeToolChoice 含 required→any/allowed_tools 降级、decodeParameters 含 max_completion_tokens 优先、decodeOutputFormat、decodeThinking 含 reasoning_effort→ThinkingConfig、废弃字段 functions→tools 
兼容)、decodeResponse(§5.2:content/refusal/reasoning_content/tool_calls 解码、finish_reason 映射表、usage 映射含 cached_tokens/reasoning_tokens)、扩展层 decode(decodeModelsResponse、decodeEmbeddingRequest/Response、decodeRerankRequest/Response);编写完整测试覆盖每类消息和字段映射 +- [x] 4.3 创建 `internal/conversion/openai/encoder.go`:实现 encodeRequest(对照 conversion_openai.md §4.2:provider.model_name 覆盖、system 注入到 messages[0]、encodeMessage 含 tool_calls 编码到 message 顶层、角色交替合并、encodeTools 含 function 包装、encodeToolChoice 含 any→required、encodeParameters 含 max_completion_tokens、encodeOutputFormat、encodeThinking 含 disabled→"none")、encodeResponse(§5.3:text→content、tool_use→tool_calls、thinking→reasoning_content、finish_reason 反向映射、usage 编码含 prompt_tokens_details)、扩展层 encode(encodeModelsResponse、encodeEmbeddingRequest/Response、encodeRerankRequest/Response);编写完整测试 +- [x] 4.4 创建 `internal/conversion/openai/adapter.go`:实现 OpenAI ProtocolAdapter(protocolName→"openai"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/chat/completions→CHAT、/v1/models→MODELS 等、buildHeaders 含 Authorization+Content-Type+OpenAI-Organization、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO/EMBEDDINGS/RERANK 返回 true、encodeError 含 ErrorCode→OpenAI 错误类型映射),组合 decoder 和 encoder 方法;编写测试覆盖所有路径模式和边界情况 +- [x] 4.5 创建 `internal/conversion/openai/stream_decoder.go`:实现 OpenAIStreamDecoder(对照 conversion_openai.md §6.2-§6.3:processChunk 解析 SSE data 行,维护状态机 messageStarted/openBlocks/toolCallIdMap/toolCallNameMap/toolCallArguments/textBlockStarted/thinkingBlockStarted/utf8Remainder/accumulatedUsage,首个 chunk→MessageStartEvent,delta.content→text block 生命周期,delta.tool_calls→tool_use block 生命周期含索引映射和参数累积,delta.reasoning_content→thinking block(非标准),delta.refusal→text block,finish_reason→关闭所有 open blocks + MessageDeltaEvent + MessageStopEvent,usage chunk→MessageDeltaEvent,[DONE]→flush 关闭);编写测试覆盖每种 delta 类型和边界情况(空 chunk、多 tool_calls、UTF-8 截断) +- [x] 4.6 创建 `internal/conversion/openai/stream_encoder.go`:实现 OpenAIStreamEncoder(对照 
conversion_openai.md §6.4:encodeEvent,ContentBlockStart 缓冲策略等待首次 ContentBlockDelta 合并输出,tool_use id/name 在首次 delta 时合并编码,text_delta 直接输出 data: {choices:[{delta:{content}}]},input_json_delta 含 tool_calls 数组编码,thinking_delta 含 reasoning_content 字段,MessageStartEvent→{choices:[{delta:{role:"assistant"}}]},MessageDeltaEvent→{choices:[{delta:{},finish_reason}]},MessageStopEvent→[DONE],PingEvent/ErrorEvent 丢弃,flush 输出缓冲区);编写测试 ## 5. Anthropic Adapter 实现(与 Layer 4 并行) -- [ ] 5.1 创建 `internal/conversion/anthropic/types.go`:从旧 `internal/protocol/anthropic/types.go` 迁移 Anthropic 线路格式类型,补全缺失字段(thinking.type 含 adaptive、output_config.format/effort、disable_parallel_tool_use、metadata.user_id、redacted_thinking、pause_turn/refusal stop_reason、stop_details、container、cache_control);编写序列化测试 -- [ ] 5.2 创建 `internal/conversion/anthropic/decoder.go`:实现 decodeRequest(对照 conversion_anthropic.md §4.1:decodeSystem 从顶层 system 提取、decodeMessage 含 tool_result 从 user 消息拆分为独立 tool 角色消息、参数直接映射含 top_k、decodeThinking 含 enabled/disabled/adaptive 三种类型、decodeOutputFormat 仅支持 json_schema、公共字段提取含 metadata.user_id/disable_parallel_tool_use 反转/output_config.effort、协议特有字段 redacted_thinking 丢弃/cache_control 忽略)、decodeResponse(§5.2:text/tool_use/thinking 块解码、redacted_thinking 丢弃、stop_reason 映射含 pause_turn/refusal、usage 映射含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 decode(decodeModelsResponse 含 RFC3339→Unix 时间戳转换、decodeModelInfoResponse);编写完整测试覆盖角色拆分、thinking 三种类型、时间戳转换 -- [ ] 5.3 创建 `internal/conversion/anthropic/encoder.go`:实现 encodeRequest(对照 conversion_anthropic.md §4.2:provider.model_name 覆盖、system 注入为顶层字段、encodeMessages 含 tool→user 合并(优先合并到相邻 user 消息)、首消息 user 保证(自动注入空 user)、角色交替合并、encodeThinkingConfig 含 enabled/disabled/adaptive、encodeOutputFormat 含 json_object→空 schema 降级/text 丢弃、公共字段编码含 metadata.user_id/disable_parallel_tool_use 反转/output_config、参数编码含 max_tokens 必填/top_k 直接映射)、encodeResponse(§5.3:text/tool_use/thinking 块直接编码、stop_reason 映射含 content_filter→end_turn 降级、usage 编码含 
cache_read_input_tokens/cache_creation_input_tokens)、扩展层 encode(encodeModelsResponse 含 Unix→RFC3339 转换和 has_more/first_id/last_id 字段、encodeModelInfoResponse);编写完整测试覆盖角色合并、首消息注入、降级处理 -- [ ] 5.4 创建 `internal/conversion/anthropic/adapter.go`:实现 Anthropic ProtocolAdapter(protocolName→"anthropic"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/messages→CHAT、/v1/models→MODELS 等、buildHeaders 含 x-api-key + anthropic-version + anthropic-beta + Content-Type、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO 返回 true 对 EMBEDDINGS/RERANK 返回 false、encodeError 返回 {type:"error",error:{type,message}});编写测试覆盖所有路径模式和边界情况 -- [ ] 5.5 创建 `internal/conversion/anthropic/stream_decoder.go`:实现 AnthropicStreamDecoder(对照 conversion_anthropic.md §6.2-§6.3:解析命名 SSE 事件 event: message_start/data: {...},1:1 映射到 CanonicalStreamEvent,维护状态 messageStarted/redactedBlocks/utf8Remainder/accumulatedUsage,redacted_thinking 检测后加入 redactedBlocks 并丢弃后续 delta/stop,citations_delta/signature_delta 直接丢弃,server_tool_use 等服务端工具块丢弃,UTF-8 跨 chunk 安全处理);编写测试覆盖所有事件类型和 redacted_thinking 丢弃 -- [ ] 5.6 创建 `internal/conversion/anthropic/stream_encoder.go`:实现 AnthropicStreamEncoder(对照 conversion_anthropic.md §6.4:直接映射无缓冲,每个 CanonicalStreamEvent 直接编码为对应的 Anthropic 命名 SSE 事件,格式 event: ``\ndata: ``\n\n,delta 编码 text_delta/input_json_delta/thinking_delta 直接映射);编写测试 +- [x] 5.1 创建 `internal/conversion/anthropic/types.go`:对照 `docs/conversion_anthropic.md` 全新定义 Anthropic 线路格式类型(不沿用旧 `internal/protocol/anthropic/types.go`),包含完整字段(thinking.type 含 adaptive、output_config.format/effort、disable_parallel_tool_use、metadata.user_id、redacted_thinking、pause_turn/refusal stop_reason、stop_details、container、cache_control);编写序列化测试 +- [x] 5.2 创建 `internal/conversion/anthropic/decoder.go`:实现 decodeRequest(对照 conversion_anthropic.md §4.1:decodeSystem 从顶层 system 提取、decodeMessage 含 tool_result 从 user 消息拆分为独立 tool 角色消息、参数直接映射含 top_k、decodeThinking 含 enabled/disabled/adaptive 三种类型、decodeOutputFormat 仅支持 json_schema、公共字段提取含 
metadata.user_id/disable_parallel_tool_use 反转/output_config.effort、协议特有字段 redacted_thinking 丢弃/cache_control 忽略)、decodeResponse(§5.2:text/tool_use/thinking 块解码、redacted_thinking 丢弃、stop_reason 映射含 pause_turn/refusal、usage 映射含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 decode(decodeModelsResponse 含 RFC3339→Unix 时间戳转换、decodeModelInfoResponse);编写完整测试覆盖角色拆分、thinking 三种类型、时间戳转换 +- [x] 5.3 创建 `internal/conversion/anthropic/encoder.go`:实现 encodeRequest(对照 conversion_anthropic.md §4.2:provider.model_name 覆盖、system 注入为顶层字段、encodeMessages 含 tool→user 合并(优先合并到相邻 user 消息)、首消息 user 保证(自动注入空 user)、角色交替合并、encodeThinkingConfig 含 enabled/disabled/adaptive、encodeOutputFormat 含 json_object→空 schema 降级/text 丢弃、公共字段编码含 metadata.user_id/disable_parallel_tool_use 反转/output_config、参数编码含 max_tokens 必填/top_k 直接映射)、encodeResponse(§5.3:text/tool_use/thinking 块直接编码、stop_reason 映射含 content_filter→end_turn 降级、usage 编码含 cache_read_input_tokens/cache_creation_input_tokens)、扩展层 encode(encodeModelsResponse 含 Unix→RFC3339 转换和 has_more/first_id/last_id 字段、encodeModelInfoResponse);编写完整测试覆盖角色合并、首消息注入、降级处理 +- [x] 5.4 创建 `internal/conversion/anthropic/adapter.go`:实现 Anthropic ProtocolAdapter(protocolName→"anthropic"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/messages→CHAT、/v1/models→MODELS 等、buildHeaders 含 x-api-key + anthropic-version + anthropic-beta + Content-Type、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO 返回 true 对 EMBEDDINGS/RERANK 返回 false、encodeError 返回 {type:"error",error:{type,message}});编写测试覆盖所有路径模式和边界情况 +- [x] 5.5 创建 `internal/conversion/anthropic/stream_decoder.go`:实现 AnthropicStreamDecoder(对照 conversion_anthropic.md §6.2-§6.3:解析命名 SSE 事件 event: message_start/data: {...},1:1 映射到 CanonicalStreamEvent,维护状态 messageStarted/redactedBlocks/utf8Remainder/accumulatedUsage,redacted_thinking 检测后加入 redactedBlocks 并丢弃后续 delta/stop,citations_delta/signature_delta 直接丢弃,server_tool_use 等服务端工具块丢弃,UTF-8 跨 chunk 安全处理);编写测试覆盖所有事件类型和 redacted_thinking 丢弃 +- [x] 5.6 创建 
`internal/conversion/anthropic/stream_encoder.go`:实现 AnthropicStreamEncoder(对照 conversion_anthropic.md §6.4:直接映射无缓冲,每个 CanonicalStreamEvent 直接编码为对应的 Anthropic 命名 SSE 事件,格式 `event: <type>\ndata: <json>\n\n`,delta 编码 text_delta/input_json_delta/thinking_delta 直接映射);编写测试 ## 6. 基础设施改造 — Provider、Handler、Domain -- [ ] 6.1 修改 `internal/domain/provider.go`:Provider 结构体新增 Protocol string 字段;修改 `internal/config/models.go`:GORM Provider 模型同步新增 Protocol 字段(gorm:"column:protocol;default:'openai'");修改 `internal/repository/` 中 toDomainProvider 和 toConfigProvider 转换函数同步 Protocol 字段;修改 `internal/handler/provider_handler.go`:CreateProvider 和 UpdateProvider 的请求结构体新增 Protocol 字段(可选,默认 "openai"),创建/更新 Provider 时赋值 Protocol 字段,List/Get 响应中包含 Protocol 字段;更新 `internal/service/service_test.go` 中所有创建测试 Provider 的地方补充 Protocol 字段;更新 `internal/handler/handler_test.go` 中 Provider CRUD 测试的请求体补充 Protocol 字段;创建数据库迁移文件 `backend/migrations/YYYYMMDDHHMMSS_add_provider_protocol.sql`:ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai' -- [ ] 6.2 重写 `internal/provider/client.go`:定义 HTTPRequestSpec 和 HTTPResponseSpec(或引用 conversion 包的定义),简化 ProviderClient 接口为 Send(ctx, HTTPRequestSpec) → (*HTTPResponseSpec, error) 和 SendStream(ctx, HTTPRequestSpec) → (<-chan StreamEvent, error),移除所有 openai.Adapter 硬编码依赖,Send 方法直接使用 http.NewRequest + spec.URL/Headers/Body,SendStream 保留现有 readStream goroutine 逻辑但输入改为 HTTPRequestSpec;重写 `provider/client_test.go`:删除所有基于 openai.ChatCompletionRequest 的旧测试用例,基于 HTTPRequestSpec 重写成功/失败/流式测试用例,使用 httptest.Server 验证请求构建和响应解析 -- [ ] 6.3 创建 `internal/handler/proxy_handler.go`:实现 ProxyHandler struct(依赖 ConversionEngine、ProviderClient、RoutingService、StatsService),实现 HandleProxy(w, r) 方法:从 URL 提取 clientProtocol(仅支持 `/{protocol}/v1/...` 前缀路由,不支持旧路由)、解析请求体 JSON、调用 RoutingService.Route(modelName) 获取路由结果(含
text/event-stream 并用 StreamConverter 逐块转换写入、错误处理使用 clientAdapter.encodeError、异步调用 StatsService.Record;编写测试使用 httptest + mock engine/client/service -- [ ] 6.4 修改 `cmd/server/main.go`:创建 AdapterRegistry 并注册 OpenAI 和 Anthropic Adapter、创建 ConversionEngine(注入 registry)、创建 ProxyHandler(注入 engine + providerClient + routingService + statsService)、配置 Gin 路由:新增 `/{protocol}/v1/{path:*}` → ProxyHandler.HandleProxy,删除旧路由 `/v1/chat/completions` 和 `/v1/messages`,移除旧的 OpenAIHandler 和 AnthropicHandler 的路由注册,移除旧的 Adapter 创建代码 +- [x] 6.1 修改 `internal/domain/provider.go`:Provider 结构体新增 Protocol string 字段;修改 `internal/config/models.go`:GORM Provider 模型同步新增 Protocol 字段(gorm:"column:protocol;default:'openai'");修改 `internal/repository/` 中 toDomainProvider 和 toConfigProvider 转换函数同步 Protocol 字段;修改 `internal/handler/provider_handler.go`:CreateProvider 和 UpdateProvider 的请求结构体新增 Protocol 字段(可选,默认 "openai"),创建/更新 Provider 时赋值 Protocol 字段,List/Get 响应中包含 Protocol 字段;更新 `internal/service/service_test.go` 中所有创建测试 Provider 的地方补充 Protocol 字段;更新 `internal/handler/handler_test.go` 中 Provider CRUD 测试的请求体补充 Protocol 字段;创建数据库迁移文件 `backend/migrations/YYYYMMDDHHMMSS_add_provider_protocol.sql`:ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai' +- [x] 6.2 重写 `internal/provider/client.go`:定义 HTTPRequestSpec 和 HTTPResponseSpec(或引用 conversion 包的定义),简化 ProviderClient 接口为 Send(ctx, HTTPRequestSpec) → (*HTTPResponseSpec, error) 和 SendStream(ctx, HTTPRequestSpec) → (<-chan StreamEvent, error),移除所有旧协议硬编码依赖,Send 方法直接使用 http.NewRequest + spec.URL/Headers/Body,SendStream 保留现有 readStream goroutine 逻辑但输入改为 HTTPRequestSpec;重写 `provider/client_test.go`:删除所有基于旧协议类型的测试用例,基于 HTTPRequestSpec 重写成功/失败/流式测试用例,使用 httptest.Server 验证请求构建和响应解析 +- [x] 6.3 创建 `internal/handler/proxy_handler.go`:实现 ProxyHandler struct(依赖 ConversionEngine、ProviderClient、RoutingService、StatsService),实现 HandleProxy(w, r) 方法:从 URL 提取 clientProtocol(仅支持 `/{protocol}/v1/...` 前缀路由,不支持旧路由)、解析请求体 JSON、调用 RoutingService.Route(modelName) 获取路由结果(含 
Provider.Protocol 作为 providerProtocol)、构建 TargetProvider、调用 engine.convertHttpRequest、调用 providerClient.Send/SendStream、调用 engine.convertHttpResponse、设置响应 Content-Type 和状态码、流式处理设置 text/event-stream 并用 StreamConverter 逐块转换写入、错误处理使用 clientAdapter.encodeError、异步调用 StatsService.Record;编写测试使用 httptest + mock engine/client/service +- [x] 6.4 修改 `cmd/server/main.go`:创建 AdapterRegistry 并注册 OpenAI 和 Anthropic Adapter、创建 ConversionEngine(注入 registry)、创建 ProxyHandler(注入 engine + providerClient + routingService + statsService)、配置 Gin 路由:新增 `/{protocol}/v1/{path:*}` → ProxyHandler.HandleProxy,删除旧路由 `/v1/chat/completions` 和 `/v1/messages`,移除旧的 OpenAIHandler 和 AnthropicHandler 的路由注册,删除旧 Adapter 创建代码 ## 7. 清理和文档 -- [ ] 7.1 删除旧代码:删除 `internal/protocol/openai/` 目录(types.go, adapter.go, adapter_test.go)、删除 `internal/protocol/anthropic/` 目录(types.go, converter.go, converter_test.go, stream_converter.go, stream_converter_test.go)、删除 `internal/handler/openai_handler.go` 和 `internal/handler/anthropic_handler.go`、删除 `internal/handler/handler_test.go` 中旧 OpenAI/Anthropic handler 测试用例和旧 `mockProviderClient`(基于 openai.ChatCompletionRequest 的签名)、重写 `handler_test.go` 为 ProxyHandler 测试(基于新 ProviderClient 接口和 ConversionEngine mock)、删除 `internal/protocol/` 空目录、确认所有编译通过且无残留 import -- [ ] 7.2 更新 `README.md`:更新项目结构说明(新增 internal/conversion/、删除 internal/protocol/)、更新 API 接口说明(代理接口变更:`/{protocol}/v1/...`,移除旧路由 `/v1/chat/completions` 和 `/v1/messages`)、更新配置说明(Provider 新增 protocol 字段) -- [ ] 7.3 端到端测试:在 `backend/tests/integration/` 中新增 `conversion_test.go`,使用 httptest mock 上游服务器验证完整请求流:OpenAI→OpenAI 同协议透传、Anthropic→Anthropic 同协议透传、OpenAI→Anthropic 跨协议非流式、Anthropic→OpenAI 跨协议非流式、4 种方向的流式转换(含 tool_calls 和 thinking)、Models 接口跨协议转换、错误响应格式验证(各协议格式)、旧路由 `/v1/chat/completions` 和 `/v1/messages` 返回 404;复用 `tests/helpers.go` 中的测试数据库和 Provider/Model 创建辅助函数 +- [x] 7.1 删除旧代码:删除 `internal/protocol/openai/` 目录(types.go, adapter.go, adapter_test.go)、删除 `internal/protocol/anthropic/` 目录(types.go, converter.go, 
converter_test.go, stream_converter.go, stream_converter_test.go)、删除 `internal/handler/openai_handler.go` 和 `internal/handler/anthropic_handler.go`、删除 `internal/handler/handler_test.go` 中旧 OpenAI/Anthropic handler 测试用例和旧 `mockProviderClient`(基于旧协议类型的签名)、重写 `handler_test.go` 为 ProxyHandler 测试(基于新 ProviderClient 接口和 ConversionEngine mock)、删除 `internal/protocol/` 空目录、确认所有编译通过且无残留 import +- [x] 7.2 更新 `README.md`:更新项目结构说明(新增 internal/conversion/、删除 internal/protocol/)、更新 API 接口说明(代理接口变更:`/{protocol}/v1/...`,移除旧路由 `/v1/chat/completions` 和 `/v1/messages`)、更新配置说明(Provider 新增 protocol 字段) +- [x] 7.3 端到端测试:在 `backend/tests/integration/` 中新增 `conversion_test.go`,使用 httptest mock 上游服务器验证完整请求流:OpenAI→OpenAI 同协议透传、Anthropic→Anthropic 同协议透传、OpenAI→Anthropic 跨协议非流式、Anthropic→OpenAI 跨协议非流式、4 种方向的流式转换(含 tool_calls 和 thinking)、Models 接口跨协议转换、错误响应格式验证(各协议格式)、旧路由 `/v1/chat/completions` 和 `/v1/messages` 返回 404;复用 `tests/helpers.go` 中的测试数据库和 Provider/Model 创建辅助函数