1
0
Files
nex/backend/internal/conversion/openai/decoder.go
lanyuanxiaoyao 1dac347d3b refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间
无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化
ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
2026-04-20 00:36:27 +08:00

670 lines
18 KiB
Go

package openai
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
)
// decodeRequest decodes an OpenAI Chat Completions request body into a
// canonical request.
//
// Errors:
//   - ErrorCodeJSONParseError when body is not valid JSON;
//   - ErrorCodeInvalidInput when model is blank or messages is empty;
//   - any error returned by decodeMessage.
//
// The deprecated functions/function_call fields are folded into
// tools/tool_choice first, and system/developer messages are lifted out of
// the conversation into System.
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
	var req ChatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 请求失败").WithCause(err)
	}

	switch {
	case strings.TrimSpace(req.Model) == "":
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
	case len(req.Messages) == 0:
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
	}

	// Rewrites req in place so the tool decoders below only see tools/tool_choice.
	decodeDeprecatedFields(&req)

	system, conversation := decodeSystemPrompt(req.Messages)

	var msgs []canonical.CanonicalMessage
	for _, m := range conversation {
		decoded, err := decodeMessage(m)
		if err != nil {
			return nil, err
		}
		msgs = append(msgs, decoded...)
	}

	return &canonical.CanonicalRequest{
		Model:           req.Model,
		System:          system,
		Messages:        msgs,
		Tools:           decodeTools(req.Tools),
		ToolChoice:      decodeToolChoice(req.ToolChoice),
		Parameters:      decodeParameters(&req),
		Thinking:        decodeThinking(req.ReasoningEffort),
		Stream:          req.Stream,
		UserID:          req.User,
		OutputFormat:    decodeOutputFormat(req.ResponseFormat),
		ParallelToolUse: req.ParallelToolCalls,
	}, nil
}
// decodeSystemPrompt splits system and developer messages out of the
// conversation.
//
// The text of each system/developer message (via extractText) is collected in
// order, skipping empty texts. If any text was collected, the texts are joined
// with a blank line ("\n\n") and returned as a string. Otherwise the first
// result is nil. The second result holds every other message in its original
// order.
func decodeSystemPrompt(messages []Message) (any, []Message) {
	var (
		prompts []string
		rest    []Message
	)
	for _, m := range messages {
		isSystem := m.Role == "system" || m.Role == "developer"
		if !isSystem {
			rest = append(rest, m)
			continue
		}
		if t := extractText(m.Content); t != "" {
			prompts = append(prompts, t)
		}
	}

	if len(prompts) > 0 {
		return strings.Join(prompts, "\n\n"), rest
	}
	return nil, rest
}
// extractText flattens an OpenAI message content value into plain text.
//
//   - nil: "".
//   - string: returned unchanged.
//   - []any: the "text" value of every element that is a map with
//     type == "text" is concatenated with no separator. All other elements
//     are ignored.
//   - any other value: formatted with %v.
func extractText(content any) string {
	switch c := content.(type) {
	case nil:
		return ""
	case string:
		return c
	case []any:
		var sb strings.Builder
		for _, elem := range c {
			part, isMap := elem.(map[string]any)
			if !isMap || part["type"] != "text" {
				continue
			}
			if s, isStr := part["text"].(string); isStr {
				sb.WriteString(s)
			}
		}
		return sb.String()
	default:
		return fmt.Sprintf("%v", c)
	}
}
// decodeMessage converts one non-system OpenAI message into canonical
// messages (zero or one).
//
// Role handling:
//   - "user": content is decoded by decodeUserContent.
//   - "assistant": the following become content blocks, in this order:
//     content text/refusal parts, the top-level refusal field, the
//     non-standard reasoning_content field, tool_calls, and the deprecated
//     function_call. If nothing is produced, a single empty text block is
//     emitted.
//   - "tool" / "function": a canonical tool message with one non-error
//     tool_result block. Its content is the message text encoded as a JSON
//     string. "tool" is keyed by tool_call_id and "function" by the function
//     name.
//
// Any other role is dropped and (nil, nil) is returned. The error result is
// currently always nil.
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
	// jsonString encodes s as a JSON string literal. fmt's %q must not be used
	// for this: it produces Go escapes (\x1b, \a, \v, \U000xxxxx) that are not
	// valid JSON. A json.RawMessage holding them then fails when re-marshalled,
	// e.g. for tool output containing ANSI escape sequences. Marshalling a
	// string cannot fail (invalid UTF-8 is replaced with U+FFFD), so the error
	// is deliberately ignored.
	jsonString := func(s string) json.RawMessage {
		b, _ := json.Marshal(s)
		return b
	}

	switch msg.Role {
	case "user":
		return []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: decodeUserContent(msg.Content)}}, nil

	case "assistant":
		var blocks []canonical.ContentBlock

		switch v := msg.Content.(type) {
		case nil:
			// No content.
		case string:
			if v != "" {
				blocks = append(blocks, canonical.NewTextBlock(v))
			}
		default:
			for _, p := range decodeContentParts(v) {
				switch p.Type {
				case "text":
					blocks = append(blocks, canonical.NewTextBlock(p.Text))
				case "refusal":
					blocks = append(blocks, canonical.NewTextBlock(p.Refusal))
				}
			}
		}

		// Top-level refusal field.
		if msg.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(msg.Refusal))
		}
		// reasoning_content is a non-standard extension field.
		if msg.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent))
		}

		for _, tc := range msg.ToolCalls {
			var name string
			input := json.RawMessage("{}")
			// name and input always come from the same source, so a call that
			// carries both Custom and Function cannot mix them.
			switch {
			case tc.Type == "custom" && tc.Custom != nil:
				// Custom tools take free-form text input; keep it as a JSON string.
				name = tc.Custom.Name
				input = jsonString(tc.Custom.Input)
			case tc.Function != nil:
				name = tc.Function.Name
				// Unparseable arguments degrade to an empty object.
				if args := json.RawMessage(tc.Function.Arguments); json.Valid(args) {
					input = args
				}
			case tc.Custom != nil:
				// Custom payload without type "custom": keep the name only.
				name = tc.Custom.Name
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}

		// Deprecated function_call. It carries no call ID, so one is generated.
		if fc := msg.FunctionCall; fc != nil {
			input := json.RawMessage(fc.Arguments)
			if !json.Valid(input) {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(generateID(), fc.Name, input))
		}

		if len(blocks) == 0 {
			blocks = append(blocks, canonical.NewTextBlock(""))
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil

	case "tool", "function":
		toolUseID := msg.ToolCallID
		if msg.Role == "function" {
			// Legacy function results carry only the function name.
			// NOTE(review): the assistant-side deprecated function_call above
			// gets a generated ID, so the two do not share an ID. Confirm how
			// downstream encoders pair them.
			toolUseID = msg.Name
		}
		isErr := false
		result := canonical.ContentBlock{
			Type:      "tool_result",
			ToolUseID: toolUseID,
			Content:   jsonString(extractText(msg.Content)),
			IsError:   &isErr,
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{result}}}, nil
	}

	// Unknown roles are dropped.
	return nil, nil
}
// decodeUserContent converts user message content into canonical blocks.
//
// A string becomes one text block, and nil becomes one empty text block. Any
// other non-array value becomes a text block holding its %v formatting. For a
// []any of content parts:
//   - "text" parts become text blocks;
//   - "image_url", "input_audio" and "file" parts become blocks of type
//     "image", "audio" and "file";
//   - unrecognised parts and non-map elements are skipped.
//
// The result is never empty: if no part is recognised, a single empty text
// block is returned.
//
// NOTE(review): the media blocks are created with only Type set. The URL or
// inline data of the part is not carried over.
func decodeUserContent(content any) []canonical.ContentBlock {
	switch c := content.(type) {
	case nil:
		return []canonical.ContentBlock{canonical.NewTextBlock("")}
	case string:
		return []canonical.ContentBlock{canonical.NewTextBlock(c)}
	case []any:
		var out []canonical.ContentBlock
		for _, elem := range c {
			part, isMap := elem.(map[string]any)
			if !isMap {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "text":
				s, _ := part["text"].(string)
				out = append(out, canonical.NewTextBlock(s))
			case "image_url":
				out = append(out, canonical.ContentBlock{Type: "image"})
			case "input_audio":
				out = append(out, canonical.ContentBlock{Type: "audio"})
			case "file":
				out = append(out, canonical.ContentBlock{Type: "file"})
			}
		}
		if len(out) == 0 {
			return []canonical.ContentBlock{canonical.NewTextBlock("")}
		}
		return out
	default:
		return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", c))}
	}
}
// contentPart is one recognised element of an OpenAI content-part array.
// Only the field matching Type ("text" → Text, "refusal" → Refusal) is set.
type contentPart struct {
	Type    string
	Text    string
	Refusal string
}

// decodeContentParts extracts the "text" and "refusal" parts from an OpenAI
// content-part array, in order.
//
// It returns nil if content is not a []any or contains no recognised parts.
// Non-map elements and other part types are skipped. A missing or non-string
// text/refusal value yields an empty string.
func decodeContentParts(content any) []contentPart {
	items, isSlice := content.([]any)
	if !isSlice {
		return nil
	}

	var parts []contentPart
	for _, elem := range items {
		m, isMap := elem.(map[string]any)
		if !isMap {
			continue
		}
		switch kind, _ := m["type"].(string); kind {
		case "text":
			s, _ := m["text"].(string)
			parts = append(parts, contentPart{Type: kind, Text: s})
		case "refusal":
			s, _ := m["refusal"].(string)
			parts = append(parts, contentPart{Type: kind, Refusal: s})
		}
	}
	return parts
}
// decodeTools converts OpenAI tool definitions into canonical tools.
//
// Only tools of type "function" with a non-nil function definition are kept.
// Their name, description and parameters schema are copied. Other tools are
// skipped. Returns nil when nothing is kept.
//
// NOTE(review): decodeToolChoice accepts {"type":"custom"} choices, but custom
// tools are skipped here. Confirm that combination is handled downstream.
func decodeTools(tools []Tool) []canonical.CanonicalTool {
	var out []canonical.CanonicalTool
	for _, t := range tools {
		if t.Type != "function" || t.Function == nil {
			continue
		}
		fn := t.Function
		out = append(out, canonical.CanonicalTool{
			Name:        fn.Name,
			Description: fn.Description,
			InputSchema: fn.Parameters,
		})
	}

	if len(out) == 0 {
		return nil
	}
	return out
}
// decodeToolChoice converts an OpenAI tool_choice value into a canonical one.
//
// String values:
//   - "auto" → auto;
//   - "none" → none;
//   - "required" → any;
//   - other strings → nil.
//
// Object values:
//   - {"type":"function","function":{"name":N}} and
//     {"type":"custom","custom":{"name":N}} → named N;
//   - {"type":"allowed_tools",...} → any if allowed_tools.mode is "required",
//     otherwise auto.
//
// Any other value (including nil) yields nil.
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
	switch tc := toolChoice.(type) {
	case string:
		switch tc {
		case "auto":
			return canonical.NewToolChoiceAuto()
		case "none":
			return canonical.NewToolChoiceNone()
		case "required":
			return canonical.NewToolChoiceAny()
		}

	case map[string]any:
		kind, _ := tc["type"].(string)
		switch kind {
		case "function", "custom":
			// The named tool lives under a key equal to the type itself.
			if named, ok := tc[kind].(map[string]any); ok {
				name, _ := named["name"].(string)
				return canonical.NewToolChoiceNamed(name)
			}
		case "allowed_tools":
			// Only the mode is mapped; the allowed subset is not represented.
			if cfg, ok := tc["allowed_tools"].(map[string]any); ok && cfg["mode"] == "required" {
				return canonical.NewToolChoiceAny()
			}
			return canonical.NewToolChoiceAuto()
		}
	}
	return nil
}
// decodeParameters copies the sampling parameters of an OpenAI request into
// canonical request parameters.
//
// temperature, top_p, frequency_penalty and presence_penalty are copied
// as-is. max_completion_tokens takes precedence over max_tokens. stop is
// normalised by normalizeStop when present.
func decodeParameters(req *ChatCompletionRequest) canonical.RequestParameters {
	params := canonical.RequestParameters{
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		FrequencyPenalty: req.FrequencyPenalty,
		PresencePenalty:  req.PresencePenalty,
	}

	switch {
	case req.MaxCompletionTokens != nil:
		params.MaxTokens = req.MaxCompletionTokens
	case req.MaxTokens != nil:
		params.MaxTokens = req.MaxTokens
	}

	if stop := req.Stop; stop != nil {
		params.StopSequences = normalizeStop(stop)
	}
	return params
}
// normalizeStop converts the OpenAI "stop" parameter into a list of stop
// sequences.
//
// It accepts a single string, a JSON-decoded array ([]any), or a []string.
// Empty strings and non-string array elements are dropped. Returns nil when
// nothing usable remains or when stop has any other type.
//
// The result never aliases the caller's slice. The []string case is filtered
// exactly like the []any case, so both input forms normalise identically.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		if v == "" {
			return nil
		}
		return []string{v}
	case []any:
		var out []string
		for _, s := range v {
			if str, ok := s.(string); ok && str != "" {
				out = append(out, str)
			}
		}
		return out
	case []string:
		var out []string
		for _, s := range v {
			if s != "" {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}
// decodeOutputFormat converts an OpenAI response_format into a canonical
// output format.
//
// "json_object" maps directly. "json_schema" carries over the schema's name,
// schema and strict flag when a json_schema object is present; otherwise it
// maps to a bare json_schema format. "text", unknown types and a nil format
// all yield nil (no structured-output constraint).
func decodeOutputFormat(format *ResponseFormat) *canonical.OutputFormat {
	if format == nil {
		return nil
	}

	switch format.Type {
	case "json_object":
		return &canonical.OutputFormat{Type: "json_object"}
	case "json_schema":
		out := &canonical.OutputFormat{Type: "json_schema"}
		if s := format.JSONSchema; s != nil {
			out.Name = s.Name
			out.Schema = s.Schema
			out.Strict = s.Strict
		}
		return out
	default:
		return nil
	}
}
// decodeThinking converts the OpenAI reasoning_effort value into a canonical
// thinking configuration.
//
//   - "" → nil (no thinking configuration);
//   - "none" → disabled;
//   - "minimal" → enabled with effort "low";
//   - any other value → enabled, with the effort passed through unchanged.
func decodeThinking(reasoningEffort string) *canonical.ThinkingConfig {
	switch reasoningEffort {
	case "":
		return nil
	case "none":
		return &canonical.ThinkingConfig{Type: "disabled"}
	case "minimal":
		return &canonical.ThinkingConfig{Type: "enabled", Effort: "low"}
	default:
		return &canonical.ThinkingConfig{Type: "enabled", Effort: reasoningEffort}
	}
}
// decodeDeprecatedFields rewrites the legacy functions/function_call fields of
// req in place onto tools/tool_choice.
//
// functions is converted to function tools only when tools is empty.
// function_call is converted only when tool_choice is unset:
//   - the strings "none" and "auto" are copied;
//   - {"name": N} becomes {"type":"function","function":{"name":N}};
//   - anything else is ignored.
func decodeDeprecatedFields(req *ChatCompletionRequest) {
	if len(req.Tools) == 0 && len(req.Functions) > 0 {
		converted := make([]Tool, 0, len(req.Functions))
		for _, fn := range req.Functions {
			converted = append(converted, Tool{
				Type: "function",
				Function: &FunctionDef{
					Name:        fn.Name,
					Description: fn.Description,
					Parameters:  fn.Parameters,
				},
			})
		}
		req.Tools = converted
	}

	if req.ToolChoice != nil || req.FunctionCall == nil {
		return
	}
	switch fc := req.FunctionCall.(type) {
	case string:
		if fc == "none" || fc == "auto" {
			req.ToolChoice = fc
		}
	case map[string]any:
		if name, ok := fc["name"].(string); ok {
			req.ToolChoice = map[string]any{
				"type":     "function",
				"function": map[string]any{"name": name},
			}
		}
	}
}
// decodeResponse decodes an OpenAI Chat Completions response body into a
// canonical response.
//
// Only the first choice is used. Its text content, top-level refusal,
// non-standard reasoning_content and tool_calls become content blocks, in
// that order. If none are produced, a single empty text block is emitted, so
// Content is never empty. A response with no choices yields ID, Model, usage
// and one empty text block.
//
// Returns an ErrorCodeJSONParseError ConversionError if body is not valid JSON.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp ChatCompletionResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 响应失败").WithCause(err)
	}

	if len(resp.Choices) == 0 {
		return &canonical.CanonicalResponse{
			ID:      resp.ID,
			Model:   resp.Model,
			Content: []canonical.ContentBlock{canonical.NewTextBlock("")},
			// decodeUsage returns the zero value when usage is absent, so
			// this only adds information when the upstream reported it.
			Usage: decodeUsage(resp.Usage),
		}, nil
	}

	choice := resp.Choices[0]
	var blocks []canonical.ContentBlock
	if msg := choice.Message; msg != nil {
		if msg.Content != nil {
			if text := extractText(msg.Content); text != "" {
				blocks = append(blocks, canonical.NewTextBlock(text))
			}
		}
		if msg.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(msg.Refusal))
		}
		if msg.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent))
		}
		for _, tc := range msg.ToolCalls {
			name := ""
			input := json.RawMessage("{}")
			switch {
			case tc.Type == "custom" && tc.Custom != nil:
				name = tc.Custom.Name
				// Custom tool input is free-form text; store it as a JSON
				// string. json.Marshal is required here, not fmt's %q: Go
				// quoting emits escapes such as \x1b that are invalid JSON.
				// Marshalling a string cannot fail, so the error is ignored.
				quoted, _ := json.Marshal(tc.Custom.Input)
				input = quoted
			case tc.Function != nil:
				name = tc.Function.Name
				// Unparseable arguments degrade to an empty object.
				if args := json.RawMessage(tc.Function.Arguments); json.Valid(args) {
					input = args
				}
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
		// NOTE(review): mapFinishReason maps "function_call" to tool use, but a
		// legacy message.function_call in the response is not turned into a
		// tool_use block here. Confirm whether the response message carries it.
	}

	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	var stopReason *canonical.StopReason
	if choice.FinishReason != nil {
		sr := mapFinishReason(*choice.FinishReason)
		stopReason = &sr
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: stopReason,
		Usage:      decodeUsage(resp.Usage),
	}, nil
}
// mapFinishReason translates an OpenAI finish_reason into a canonical stop
// reason:
//   - "length" → max tokens;
//   - "tool_calls" and the legacy "function_call" → tool use;
//   - "content_filter" → content filter;
//   - "stop" and any unrecognised value → end of turn.
func mapFinishReason(reason string) canonical.StopReason {
	switch reason {
	case "length":
		return canonical.StopReasonMaxTokens
	case "tool_calls", "function_call":
		return canonical.StopReasonToolUse
	case "content_filter":
		return canonical.StopReasonContentFilter
	default:
		return canonical.StopReasonEndTurn
	}
}
// decodeUsage converts OpenAI token usage into canonical usage.
//
// A nil usage yields the zero value. prompt_tokens and completion_tokens map
// to input and output tokens. Cached prompt tokens and reasoning tokens are
// set (as non-nil pointers) only when their details object is present and the
// count is positive.
func decodeUsage(usage *Usage) canonical.CanonicalUsage {
	var out canonical.CanonicalUsage
	if usage == nil {
		return out
	}

	out.InputTokens = usage.PromptTokens
	out.OutputTokens = usage.CompletionTokens

	if d := usage.PromptTokensDetails; d != nil && d.CachedTokens > 0 {
		cached := d.CachedTokens
		out.CacheReadTokens = &cached
	}
	if d := usage.CompletionTokensDetails; d != nil && d.ReasoningTokens > 0 {
		reasoning := d.ReasoningTokens
		out.ReasoningTokens = &reasoning
	}
	return out
}
// decodeModelsResponse decodes an OpenAI model-list response body into a
// canonical model list. Each model's ID is also used as its Name.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, the same error shape decodeRequest and decodeResponse use.
// Callers can therefore classify all decode failures uniformly.
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
	var resp ModelsResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 模型列表响应失败").WithCause(err)
	}

	models := make([]canonical.CanonicalModel, len(resp.Data))
	for i, m := range resp.Data {
		models[i] = canonical.CanonicalModel{
			ID:      m.ID,
			Name:    m.ID,
			Created: m.Created,
			OwnedBy: m.OwnedBy,
		}
	}
	return &canonical.CanonicalModelList{Models: models}, nil
}
// decodeModelInfoResponse decodes an OpenAI single-model response body into
// canonical model info. The model ID is also used as its Name.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, matching decodeRequest and decodeResponse.
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
	var resp ModelInfoResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 模型详情响应失败").WithCause(err)
	}
	return &canonical.CanonicalModelInfo{
		ID:      resp.ID,
		Name:    resp.ID,
		Created: resp.Created,
		OwnedBy: resp.OwnedBy,
	}, nil
}
// decodeEmbeddingRequest decodes an OpenAI embeddings request body into a
// canonical embedding request. Model, input, encoding format and dimensions
// are copied unchanged.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, matching decodeRequest.
func decodeEmbeddingRequest(body []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	var req EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 嵌入请求失败").WithCause(err)
	}
	return &canonical.CanonicalEmbeddingRequest{
		Model:          req.Model,
		Input:          req.Input,
		EncodingFormat: req.EncodingFormat,
		Dimensions:     req.Dimensions,
	}, nil
}
// decodeEmbeddingResponse decodes an OpenAI embeddings response body into a
// canonical embedding response. It copies the index and vector of each item,
// the model name, and the prompt/total token usage.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, matching decodeResponse.
func decodeEmbeddingResponse(body []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	var resp EmbeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 嵌入响应失败").WithCause(err)
	}

	data := make([]canonical.EmbeddingData, len(resp.Data))
	for i, d := range resp.Data {
		data[i] = canonical.EmbeddingData{Index: d.Index, Embedding: d.Embedding}
	}

	return &canonical.CanonicalEmbeddingResponse{
		Data:  data,
		Model: resp.Model,
		// NOTE(review): assumes EmbeddingResponse.Usage is a value, not a
		// pointer. If it is a pointer, a response without usage would panic
		// here. Confirm.
		Usage: canonical.EmbeddingUsage{
			PromptTokens: resp.Usage.PromptTokens,
			TotalTokens:  resp.Usage.TotalTokens,
		},
	}, nil
}
// decodeRerankRequest decodes a rerank request body in this adapter's
// OpenAI-style format into a canonical rerank request. Model, query,
// documents, top_n and return_documents are copied unchanged.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, matching decodeRequest.
func decodeRerankRequest(body []byte) (*canonical.CanonicalRerankRequest, error) {
	var req RerankRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 重排序请求失败").WithCause(err)
	}
	return &canonical.CanonicalRerankRequest{
		Model:           req.Model,
		Query:           req.Query,
		Documents:       req.Documents,
		TopN:            req.TopN,
		ReturnDocuments: req.ReturnDocuments,
	}, nil
}
// decodeRerankResponse decodes a rerank response body into a canonical rerank
// response. It copies the index, relevance score and document of each result,
// plus the model name.
//
// A body that is not valid JSON yields an ErrorCodeJSONParseError
// ConversionError, matching decodeResponse.
func decodeRerankResponse(body []byte) (*canonical.CanonicalRerankResponse, error) {
	var resp RerankResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 重排序响应失败").WithCause(err)
	}

	results := make([]canonical.RerankResult, len(resp.Results))
	for i, r := range resp.Results {
		results[i] = canonical.RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
			Document:       r.Document,
		}
	}
	return &canonical.CanonicalRerankResponse{Results: results, Model: resp.Model}, nil
}
// idCounter backs generateCounter. It is only accessed through sync/atomic,
// so concurrent callers each observe a distinct value.
var idCounter int64

// generateCounter atomically increments idCounter by one and returns the new
// value.
func generateCounter() int64 {
	next := atomic.AddInt64(&idCounter, 1)
	return next
}

// generateID returns a tool-call ID of the form "call_<n>", where n comes
// from generateCounter. IDs are unique only within one process run: the
// counter is a package variable that starts from zero.
func generateID() string {
	n := generateCounter()
	return fmt.Sprintf("call_%d", n)
}