refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化 ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
This commit is contained in:
211
backend/internal/conversion/openai/adapter.go
Normal file
211
backend/internal/conversion/openai/adapter.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter implements the protocol adapter for the OpenAI wire format.
// The zero value is usable; NewAdapter exists for symmetry with other adapters.
type Adapter struct{}

// NewAdapter returns a ready-to-use OpenAI adapter.
func NewAdapter() *Adapter {
	return &Adapter{}
}

// modelInfoRegex matches model-detail paths of the form /v1/models/<id>,
// where <id> contains no further slash.
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)

// ProtocolName identifies this adapter as the OpenAI protocol.
func (a *Adapter) ProtocolName() string {
	return "openai"
}

// ProtocolVersion reports no explicit version for the OpenAI protocol.
func (a *Adapter) ProtocolVersion() string {
	return ""
}

// SupportsPassthrough reports that same-protocol requests may be forwarded
// without canonical conversion.
func (a *Adapter) SupportsPassthrough() bool {
	return true
}
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/chat/completions":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case modelInfoRegex.MatchString(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
case nativePath == "/v1/embeddings":
|
||||
return conversion.InterfaceTypeEmbeddings
|
||||
case nativePath == "/v1/rerank":
|
||||
return conversion.InterfaceTypeRerank
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/v1/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
return "/v1/rerank"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + provider.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
|
||||
headers["OpenAI-Organization"] = org
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo,
|
||||
conversion.InterfaceTypeEmbeddings,
|
||||
conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest parses a raw OpenAI chat request into the canonical model.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest serializes a canonical request as an OpenAI chat request.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse parses a raw OpenAI chat response into the canonical model.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse serializes a canonical response as an OpenAI chat response.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder returns a fresh SSE stream decoder for OpenAI chunks.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder returns a fresh SSE stream encoder for OpenAI chunks.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}
|
||||
|
||||
// EncodeError renders a conversion error as an OpenAI-style error payload.
// It returns the JSON body and the HTTP status code to send to the client.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	errType := mapErrorCode(err.Code)
	// NOTE(review): every error is reported as 500, including
	// invalid_request_error cases that the OpenAI API itself serves as 400.
	// The adapter tests currently assert 500 — confirm intent before changing.
	statusCode := 500

	errMsg := ErrorResponse{
		Error: ErrorDetail{
			Message: err.Message,
			Type:    errType,
			Param:   nil,
			Code:    string(err.Code),
		},
	}
	// Marshaling this plain struct cannot fail; the error is deliberately ignored.
	body, _ := json.Marshal(errMsg)
	return body, statusCode
}
|
||||
|
||||
// mapErrorCode 映射错误码到 OpenAI 错误类型
|
||||
func mapErrorCode(code conversion.ErrorCode) string {
|
||||
switch code {
|
||||
case conversion.ErrorCodeInvalidInput,
|
||||
conversion.ErrorCodeMissingRequiredField,
|
||||
conversion.ErrorCodeIncompatibleFeature,
|
||||
conversion.ErrorCodeToolCallParseError,
|
||||
conversion.ErrorCodeJSONParseError,
|
||||
conversion.ErrorCodeProtocolConstraint,
|
||||
conversion.ErrorCodeFieldMappingFailure:
|
||||
return "invalid_request_error"
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeModelsResponse parses an OpenAI model-list response into canonical form.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse serializes a canonical model list as an OpenAI response.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse parses an OpenAI model-detail response into canonical form.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse serializes canonical model details as an OpenAI response.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}

// DecodeEmbeddingRequest parses an OpenAI embeddings request into canonical form.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return decodeEmbeddingRequest(raw)
}

// EncodeEmbeddingRequest serializes a canonical embeddings request for OpenAI.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeEmbeddingRequest(req, provider)
}

// DecodeEmbeddingResponse parses an OpenAI embeddings response into canonical form.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return decodeEmbeddingResponse(raw)
}

// EncodeEmbeddingResponse serializes a canonical embeddings response for OpenAI.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return encodeEmbeddingResponse(resp)
}

// DecodeRerankRequest parses an OpenAI rerank request into canonical form.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return decodeRerankRequest(raw)
}

// EncodeRerankRequest serializes a canonical rerank request for OpenAI.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRerankRequest(req, provider)
}

// DecodeRerankResponse parses an OpenAI rerank response into canonical form.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return decodeRerankResponse(raw)
}

// EncodeRerankResponse serializes a canonical rerank response for OpenAI.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return encodeRerankResponse(resp)
}
|
||||
139
backend/internal/conversion/openai/adapter_test.go
Normal file
139
backend/internal/conversion/openai/adapter_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAdapter_ProtocolName verifies the adapter reports the "openai" protocol.
func TestAdapter_ProtocolName(t *testing.T) {
	a := NewAdapter()
	assert.Equal(t, "openai", a.ProtocolName())
}

// TestAdapter_SupportsPassthrough verifies same-protocol passthrough is enabled.
func TestAdapter_SupportsPassthrough(t *testing.T) {
	a := NewAdapter()
	assert.True(t, a.SupportsPassthrough())
}

// TestAdapter_DetectInterfaceType table-tests path → interface-type routing,
// including the regex-matched model-detail path and the passthrough fallback.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name     string
		path     string
		expected conversion.InterfaceType
	}{
		{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
		{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
		{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
		{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
		{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
		{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.DetectInterfaceType(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_BuildUrl table-tests interface-type → endpoint mapping and the
// passthrough default that returns the native path unchanged.
func TestAdapter_BuildUrl(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		expected      string
	}{
		{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
		{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
		{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
		{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
		{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.BuildUrl(tt.nativePath, tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_BuildHeaders verifies bearer auth and content type are always
// set, and the OpenAI-Organization header only when configured.
func TestAdapter_BuildHeaders(t *testing.T) {
	a := NewAdapter()

	t.Run("基本头", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])
		_, hasOrg := headers["OpenAI-Organization"]
		assert.False(t, hasOrg)
	})

	t.Run("带组织", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		provider.AdapterConfig["organization"] = "org-abc"
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "org-abc", headers["OpenAI-Organization"])
	})
}
|
||||
|
||||
// TestAdapter_SupportsInterface table-tests which interface types the adapter
// can convert; passthrough is expected to be unsupported.
func TestAdapter_SupportsInterface(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		interfaceType conversion.InterfaceType
		expected      bool
	}{
		{"聊天", conversion.InterfaceTypeChat, true},
		{"模型", conversion.InterfaceTypeModels, true},
		{"模型详情", conversion.InterfaceTypeModelInfo, true},
		{"嵌入", conversion.InterfaceTypeEmbeddings, true},
		{"重排序", conversion.InterfaceTypeRerank, true},
		{"透传", conversion.InterfaceTypePassthrough, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.SupportsInterface(tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_EncodeError_InvalidInput checks that client-caused codes map to
// the invalid_request_error type (status is currently always 500).
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "参数无效", resp.Error.Message)
	assert.Equal(t, "invalid_request_error", resp.Error.Type)
}

// TestAdapter_EncodeError_ServerError checks that non-client codes map to the
// server_error type with the message preserved.
func TestAdapter_EncodeError_ServerError(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "server_error", resp.Error.Type)
	assert.Equal(t, "流状态错误", resp.Error.Message)
}
|
||||
669
backend/internal/conversion/openai/decoder.go
Normal file
669
backend/internal/conversion/openai/decoder.go
Normal file
@@ -0,0 +1,669 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest 将 OpenAI 请求解码为 Canonical 请求
|
||||
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
var req ChatCompletionRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 请求失败").WithCause(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
|
||||
}
|
||||
if len(req.Messages) == 0 {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
|
||||
}
|
||||
|
||||
// 废弃字段兼容
|
||||
decodeDeprecatedFields(&req)
|
||||
|
||||
system, messages := decodeSystemPrompt(req.Messages)
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range messages {
|
||||
decoded, err := decodeMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
tools := decodeTools(req.Tools)
|
||||
toolChoice := decodeToolChoice(req.ToolChoice)
|
||||
params := decodeParameters(&req)
|
||||
outputFormat := decodeOutputFormat(req.ResponseFormat)
|
||||
thinking := decodeThinking(req.ReasoningEffort)
|
||||
|
||||
var parallelToolUse *bool
|
||||
if req.ParallelToolCalls != nil {
|
||||
parallelToolUse = req.ParallelToolCalls
|
||||
}
|
||||
|
||||
return &canonical.CanonicalRequest{
|
||||
Model: req.Model,
|
||||
System: system,
|
||||
Messages: canonicalMsgs,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
Parameters: params,
|
||||
Thinking: thinking,
|
||||
Stream: req.Stream,
|
||||
UserID: req.User,
|
||||
OutputFormat: outputFormat,
|
||||
ParallelToolUse: parallelToolUse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeSystemPrompt 提取 system 和 developer 消息
|
||||
func decodeSystemPrompt(messages []Message) (any, []Message) {
|
||||
var systemParts []string
|
||||
var remaining []Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" || msg.Role == "developer" {
|
||||
text := extractText(msg.Content)
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
} else {
|
||||
remaining = append(remaining, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(systemParts) == 0 {
|
||||
return nil, remaining
|
||||
}
|
||||
return strings.Join(systemParts, "\n\n"), remaining
|
||||
}
|
||||
|
||||
// extractText flattens an OpenAI message content value into plain text.
// Strings pass through unchanged; part arrays contribute only their "text"
// parts (concatenated); nil becomes ""; anything else falls back to fmt's
// default formatting.
func extractText(content any) string {
	if content == nil {
		return ""
	}
	if s, ok := content.(string); ok {
		return s
	}
	items, ok := content.([]any)
	if !ok {
		return fmt.Sprintf("%v", content)
	}
	var b strings.Builder
	for _, item := range items {
		part, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if kind, _ := part["type"].(string); kind != "text" {
			continue
		}
		if text, ok := part["text"].(string); ok {
			b.WriteString(text)
		}
	}
	return b.String()
}
|
||||
|
||||
// decodeMessage converts one OpenAI chat message into zero or more canonical
// messages, keyed on the message role:
//
//   - "user": content parts become canonical content blocks.
//   - "assistant": text/refusal/reasoning content and tool calls each become
//     their own block; an empty assistant message yields one empty text block.
//   - "tool"/"function": the content becomes a single tool_result block.
//
// Unknown roles return (nil, nil) and are silently dropped — system/developer
// roles are expected to have been stripped by decodeSystemPrompt beforehand.
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
	switch msg.Role {
	case "user":
		blocks := decodeUserContent(msg.Content)
		return []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: blocks}}, nil

	case "assistant":
		var blocks []canonical.ContentBlock
		// content: a plain string becomes one text block; a part array
		// contributes its text and refusal parts as text blocks.
		if msg.Content != nil {
			switch v := msg.Content.(type) {
			case string:
				if v != "" {
					blocks = append(blocks, canonical.NewTextBlock(v))
				}
			default:
				parts := decodeContentParts(msg.Content)
				for _, p := range parts {
					if p.Type == "text" {
						blocks = append(blocks, canonical.NewTextBlock(p.Text))
					} else if p.Type == "refusal" {
						blocks = append(blocks, canonical.NewTextBlock(p.Refusal))
					}
				}
			}
		}
		// Top-level refusal field is folded in as plain text.
		if msg.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(msg.Refusal))
		}
		// reasoning_content is a non-standard extension carrying model thinking.
		if msg.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent))
		}
		// tool_calls: custom tool input is wrapped as a JSON string; function
		// arguments are passed through when valid JSON, else replaced by "{}".
		for _, tc := range msg.ToolCalls {
			var input json.RawMessage
			if tc.Type == "custom" && tc.Custom != nil {
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
			} else if tc.Function != nil {
				parsed := json.RawMessage(tc.Function.Arguments)
				if !json.Valid(parsed) {
					parsed = json.RawMessage("{}")
				}
				input = parsed
			} else {
				input = json.RawMessage("{}")
			}
			name := ""
			if tc.Function != nil {
				name = tc.Function.Name
			} else if tc.Custom != nil {
				name = tc.Custom.Name
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
		// Deprecated function_call: has no call ID, so one is generated.
		if msg.FunctionCall != nil {
			input := json.RawMessage(msg.FunctionCall.Arguments)
			if !json.Valid(input) {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(generateID(), msg.FunctionCall.Name, input))
		}
		if len(blocks) == 0 {
			blocks = append(blocks, canonical.NewTextBlock(""))
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil

	case "tool":
		content := extractText(msg.Content)
		isErr := false
		block := canonical.ContentBlock{
			Type:      "tool_result",
			ToolUseID: msg.ToolCallID,
			Content:   json.RawMessage(fmt.Sprintf("%q", content)),
			IsError:   &isErr,
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil

	case "function":
		content := extractText(msg.Content)
		isErr := false
		block := canonical.ContentBlock{
			Type: "tool_result",
			// NOTE(review): legacy function results carry no call ID; the
			// function name stands in — confirm downstream ID matching copes.
			ToolUseID: msg.Name,
			Content:   json.RawMessage(fmt.Sprintf("%q", content)),
			IsError:   &isErr,
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil
	}

	return nil, nil
}
|
||||
|
||||
// decodeUserContent 解码用户内容
|
||||
func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
case "input_audio":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "audio"})
|
||||
case "file":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "file"})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
}
|
||||
}
|
||||
|
||||
// contentPart is a flattened view of one assistant content part: either a
// "text" part or a "refusal" part.
type contentPart struct {
	Type    string
	Text    string
	Refusal string
}

// decodeContentParts flattens an assistant content array into contentPart
// values, keeping only text and refusal parts. Non-slice input yields nil.
func decodeContentParts(content any) []contentPart {
	items, ok := content.([]any)
	if !ok {
		return nil
	}
	var result []contentPart
	for _, item := range items {
		part, ok := item.(map[string]any)
		if !ok {
			continue
		}
		switch kind, _ := part["type"].(string); kind {
		case "text":
			text, _ := part["text"].(string)
			result = append(result, contentPart{Type: "text", Text: text})
		case "refusal":
			refusal, _ := part["refusal"].(string)
			result = append(result, contentPart{Type: "refusal", Refusal: refusal})
		}
	}
	return result
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
var result []canonical.CanonicalTool
|
||||
for _, tool := range tools {
|
||||
if tool.Type == "function" && tool.Function != nil {
|
||||
result = append(result, canonical.CanonicalTool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
InputSchema: tool.Function.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice 解码工具选择
|
||||
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
if toolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := toolChoice.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "required":
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
switch t {
|
||||
case "function":
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
name, _ := fn["name"].(string)
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "custom":
|
||||
if custom, ok := v["custom"].(map[string]any); ok {
|
||||
name, _ := custom["name"].(string)
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "allowed_tools":
|
||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||
mode, _ := at["mode"].(string)
|
||||
if mode == "required" {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
return canonical.NewToolChoiceAuto()
|
||||
}
|
||||
return canonical.NewToolChoiceAuto()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeParameters 解码请求参数
|
||||
func decodeParameters(req *ChatCompletionRequest) canonical.RequestParameters {
|
||||
params := canonical.RequestParameters{
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
params.MaxTokens = req.MaxCompletionTokens
|
||||
} else if req.MaxTokens != nil {
|
||||
params.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.Stop != nil {
|
||||
params.StopSequences = normalizeStop(req.Stop)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// normalizeStop coerces the polymorphic OpenAI stop parameter into a string
// slice. Empty strings are dropped, and an effectively empty value (empty
// string, all-empty array, or any unsupported type) yields nil.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		if v == "" {
			return nil
		}
		return []string{v}
	case []string:
		// Already typed — only possible when set programmatically, since
		// json.Unmarshal into any never produces []string.
		return v
	case []any:
		result := make([]string, 0, len(v))
		for _, item := range v {
			if s, ok := item.(string); ok && s != "" {
				result = append(result, s)
			}
		}
		if len(result) == 0 {
			return nil
		}
		return result
	default:
		return nil
	}
}
|
||||
|
||||
// decodeOutputFormat 解码输出格式
|
||||
func decodeOutputFormat(format *ResponseFormat) *canonical.OutputFormat {
|
||||
if format == nil {
|
||||
return nil
|
||||
}
|
||||
switch format.Type {
|
||||
case "json_object":
|
||||
return &canonical.OutputFormat{Type: "json_object"}
|
||||
case "json_schema":
|
||||
if format.JSONSchema != nil {
|
||||
return &canonical.OutputFormat{
|
||||
Type: "json_schema",
|
||||
Name: format.JSONSchema.Name,
|
||||
Schema: format.JSONSchema.Schema,
|
||||
Strict: format.JSONSchema.Strict,
|
||||
}
|
||||
}
|
||||
return &canonical.OutputFormat{Type: "json_schema"}
|
||||
case "text":
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeThinking 解码推理配置
|
||||
func decodeThinking(reasoningEffort string) *canonical.ThinkingConfig {
|
||||
if reasoningEffort == "" {
|
||||
return nil
|
||||
}
|
||||
if reasoningEffort == "none" {
|
||||
return &canonical.ThinkingConfig{Type: "disabled"}
|
||||
}
|
||||
effort := reasoningEffort
|
||||
if effort == "minimal" {
|
||||
effort = "low"
|
||||
}
|
||||
return &canonical.ThinkingConfig{Type: "enabled", Effort: effort}
|
||||
}
|
||||
|
||||
// decodeDeprecatedFields rewrites the deprecated functions/function_call
// request fields into their modern tools/tool_choice equivalents, in place.
// Modern fields always win: the deprecated ones are only consulted when the
// corresponding modern field is absent.
func decodeDeprecatedFields(req *ChatCompletionRequest) {
	// functions → tools (only when no tools were given).
	if len(req.Tools) == 0 && len(req.Functions) > 0 {
		req.Tools = make([]Tool, len(req.Functions))
		for i, f := range req.Functions {
			req.Tools[i] = Tool{
				Type: "function",
				Function: &FunctionDef{
					Name:        f.Name,
					Description: f.Description,
					Parameters:  f.Parameters,
				},
			}
		}
	}
	// function_call → tool_choice (only when no tool_choice was given).
	if req.ToolChoice == nil && req.FunctionCall != nil {
		switch v := req.FunctionCall.(type) {
		case string:
			// Only "none"/"auto" translate; other strings are dropped.
			switch v {
			case "none":
				req.ToolChoice = "none"
			case "auto":
				req.ToolChoice = "auto"
			}
		case map[string]any:
			// {"name": X} becomes the modern named-function object form.
			if name, ok := v["name"].(string); ok {
				req.ToolChoice = map[string]any{
					"type":     "function",
					"function": map[string]any{"name": name},
				}
			}
		}
	}
}
|
||||
|
||||
// decodeResponse converts a raw (non-streaming) OpenAI chat response into the
// canonical response model. Only the first choice is used; a response with no
// choices yields a single empty text block and zero usage.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp ChatCompletionResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 响应失败").WithCause(err)
	}

	if len(resp.Choices) == 0 {
		return &canonical.CanonicalResponse{
			ID:      resp.ID,
			Model:   resp.Model,
			Content: []canonical.ContentBlock{canonical.NewTextBlock("")},
			Usage:   canonical.CanonicalUsage{},
		}, nil
	}

	// NOTE(review): choices beyond the first (n > 1) are silently dropped.
	choice := resp.Choices[0]
	var blocks []canonical.ContentBlock

	if choice.Message != nil {
		// Text content (string or part array) becomes one text block.
		if choice.Message.Content != nil {
			text := extractText(choice.Message.Content)
			if text != "" {
				blocks = append(blocks, canonical.NewTextBlock(text))
			}
		}
		// Refusals are folded in as plain text.
		if choice.Message.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(choice.Message.Refusal))
		}
		// Non-standard reasoning_content becomes a thinking block.
		if choice.Message.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(choice.Message.ReasoningContent))
		}
		// Tool calls: custom tool input is wrapped as a JSON string; function
		// arguments pass through when valid JSON, else are replaced by "{}".
		for _, tc := range choice.Message.ToolCalls {
			var input json.RawMessage
			name := ""
			if tc.Type == "custom" && tc.Custom != nil {
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
				name = tc.Custom.Name
			} else if tc.Function != nil {
				input = json.RawMessage(tc.Function.Arguments)
				if !json.Valid(input) {
					input = json.RawMessage("{}")
				}
				name = tc.Function.Name
			} else {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
	}

	// Guarantee at least one content block.
	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	var stopReason *canonical.StopReason
	if choice.FinishReason != nil {
		sr := mapFinishReason(*choice.FinishReason)
		stopReason = &sr
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: stopReason,
		Usage:      decodeUsage(resp.Usage),
	}, nil
}
|
||||
|
||||
// mapFinishReason 映射结束原因
|
||||
func mapFinishReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "stop":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "length":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_calls":
|
||||
return canonical.StopReasonToolUse
|
||||
case "function_call":
|
||||
return canonical.StopReasonToolUse
|
||||
case "content_filter":
|
||||
return canonical.StopReasonContentFilter
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeUsage 解码用量
|
||||
func decodeUsage(usage *Usage) canonical.CanonicalUsage {
|
||||
if usage == nil {
|
||||
return canonical.CanonicalUsage{}
|
||||
}
|
||||
result := canonical.CanonicalUsage{
|
||||
InputTokens: usage.PromptTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
}
|
||||
if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
val := usage.PromptTokensDetails.CachedTokens
|
||||
result.CacheReadTokens = &val
|
||||
}
|
||||
if usage.CompletionTokensDetails != nil && usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
val := usage.CompletionTokensDetails.ReasoningTokens
|
||||
result.ReasoningTokens = &val
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeModelsResponse parses an OpenAI model-list response into the canonical
// list. The model ID doubles as the display name, since OpenAI provides none.
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
	var resp ModelsResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	models := make([]canonical.CanonicalModel, len(resp.Data))
	for i, m := range resp.Data {
		models[i] = canonical.CanonicalModel{
			ID:      m.ID,
			Name:    m.ID,
			Created: m.Created,
			OwnedBy: m.OwnedBy,
		}
	}
	return &canonical.CanonicalModelList{Models: models}, nil
}

// decodeModelInfoResponse parses an OpenAI model-detail response into the
// canonical model info. As above, the ID doubles as the name.
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
	var resp ModelInfoResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	return &canonical.CanonicalModelInfo{
		ID:      resp.ID,
		Name:    resp.ID,
		Created: resp.Created,
		OwnedBy: resp.OwnedBy,
	}, nil
}

// decodeEmbeddingRequest parses an OpenAI embeddings request into canonical form.
func decodeEmbeddingRequest(body []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	var req EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalEmbeddingRequest{
		Model:          req.Model,
		Input:          req.Input,
		EncodingFormat: req.EncodingFormat,
		Dimensions:     req.Dimensions,
	}, nil
}

// decodeEmbeddingResponse parses an OpenAI embeddings response into canonical form.
func decodeEmbeddingResponse(body []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	var resp EmbeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	data := make([]canonical.EmbeddingData, len(resp.Data))
	for i, d := range resp.Data {
		data[i] = canonical.EmbeddingData{Index: d.Index, Embedding: d.Embedding}
	}
	return &canonical.CanonicalEmbeddingResponse{
		Data:  data,
		Model: resp.Model,
		Usage: canonical.EmbeddingUsage{
			PromptTokens: resp.Usage.PromptTokens,
			TotalTokens:  resp.Usage.TotalTokens,
		},
	}, nil
}

// decodeRerankRequest parses an OpenAI-style rerank request into canonical form.
func decodeRerankRequest(body []byte) (*canonical.CanonicalRerankRequest, error) {
	var req RerankRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &canonical.CanonicalRerankRequest{
		Model:           req.Model,
		Query:           req.Query,
		Documents:       req.Documents,
		TopN:            req.TopN,
		ReturnDocuments: req.ReturnDocuments,
	}, nil
}

// decodeRerankResponse parses an OpenAI-style rerank response into canonical form.
func decodeRerankResponse(body []byte) (*canonical.CanonicalRerankResponse, error) {
	var resp RerankResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	results := make([]canonical.RerankResult, len(resp.Results))
	for i, r := range resp.Results {
		results[i] = canonical.RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
			Document:       r.Document,
		}
	}
	return &canonical.CanonicalRerankResponse{Results: results, Model: resp.Model}, nil
}
|
||||
|
||||
// idCounter backs generateID. It is only ever touched through atomic
// operations, so concurrent decoding cannot produce duplicate IDs.
var idCounter atomic.Int64

// generateID returns a process-unique synthetic tool-call ID of the
// form "call_<n>". Uniqueness holds per process lifetime only (the
// counter restarts at 1 after a restart), which is sufficient for
// correlating tool calls within a single exchange.
func generateID() string {
	return fmt.Sprintf("call_%d", generateCounter())
}

// generateCounter atomically increments the shared counter and returns
// the new value. Uses the typed atomic.Int64 instead of the legacy
// atomic.AddInt64-on-plain-int64 pattern, which left the counter open
// to accidental non-atomic access elsewhere in the package.
func generateCounter() int64 {
	return idCounter.Add(1)
}
|
||||
411
backend/internal/conversion/openai/decoder_test.go
Normal file
411
backend/internal/conversion/openai/decoder_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeRequest_BasicChat verifies model, message, and temperature
// mapping for a minimal chat request.
func TestDecodeRequest_BasicChat(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "你好"}
		],
		"temperature": 0.7
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.7, *req.Parameters.Temperature)
}

// TestDecodeRequest_SystemAndDeveloper verifies that system and
// developer messages are hoisted out of the message list and joined
// (blank-line separated) into the canonical System field.
func TestDecodeRequest_SystemAndDeveloper(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "system", "content": "你是助手"},
			{"role": "developer", "content": "额外指令"},
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手\n\n额外指令", req.System)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
}

// TestDecodeRequest_ToolCalls verifies that an assistant tool_calls
// entry decodes into a canonical tool_use content block carrying the
// original id and function name.
func TestDecodeRequest_ToolCalls(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{
					"id": "call_123",
					"type": "function",
					"function": {"name": "get_weather", "arguments": "{\"city\":\"北京\"}"}
				}]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Messages, 2)
	assistantMsg := req.Messages[1]
	assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
	found := false
	for _, b := range assistantMsg.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_123", b.ID)
			assert.Equal(t, "get_weather", b.Name)
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
// TestDecodeRequest_ToolMessage verifies that a "tool" role message
// keeps its tool_call_id on the canonical tool_result block.
func TestDecodeRequest_ToolMessage(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]
			},
			{
				"role": "tool",
				"tool_call_id": "call_1",
				"content": "晴天 25°C"
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	toolMsg := req.Messages[2]
	assert.Equal(t, canonical.RoleTool, toolMsg.Role)
	assert.Equal(t, "call_1", toolMsg.Content[0].ToolUseID)
}

// TestDecodeRequest_MissingModel verifies that a request without a
// model is rejected with an INVALID_INPUT error.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_MissingMessages verifies that a request without
// messages is rejected with an INVALID_INPUT error.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model":"gpt-4"}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_DeprecatedFunctions verifies that the legacy
// "functions" field is promoted to canonical tools.
func TestDecodeRequest_DeprecatedFunctions(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "test"}],
		"functions": [{
			"name": "get_weather",
			"description": "获取天气",
			"parameters": {"type":"object","properties":{"city":{"type":"string"}}}
		}]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Tools, 1)
	assert.Equal(t, "get_weather", req.Tools[0].Name)
}
|
||||
|
||||
// TestDecodeResponse_Basic verifies id/model/content/stop-reason/usage
// mapping for a plain completion.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-123",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": "你好"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", resp.ID)
	assert.Equal(t, "gpt-4", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.NotNil(t, resp.StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
	assert.Equal(t, 5, resp.Usage.OutputTokens)
}

// TestDecodeResponse_ToolCalls verifies that tool_calls in the response
// message become canonical tool_use blocks and that the finish reason
// maps to tool_use.
func TestDecodeResponse_ToolCalls(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-456",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"tool_calls": [{
					"id": "call_abc",
					"type": "function",
					"function": {"name": "search", "arguments": "{\"q\":\"test\"}"}
				}]
			},
			"finish_reason": "tool_calls"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_abc", b.ID)
			assert.Equal(t, "search", b.Name)
		}
	}
	assert.True(t, found)
	assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
}

// TestDecodeResponse_Thinking verifies that reasoning_content is
// decoded into a separate canonical thinking block alongside the text.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-789",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"content": "回答",
				"reasoning_content": "思考过程"
			},
			"finish_reason": "stop"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "回答", resp.Content[0].Text)
	assert.Equal(t, "thinking", resp.Content[1].Type)
	assert.Equal(t, "思考过程", resp.Content[1].Thinking)
}

// TestDecodeModelsResponse verifies id and created-timestamp mapping
// for a model list payload.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"object": "list",
		"data": [
			{"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "openai"},
			{"id": "gpt-3.5-turbo", "object": "model", "created": 1700000001, "owned_by": "openai"}
		]
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "gpt-4", list.Models[0].ID)
	assert.Equal(t, "gpt-3.5-turbo", list.Models[1].ID)
	assert.Equal(t, int64(1700000000), list.Models[0].Created)
}
|
||||
|
||||
// TestDecodeRequest_InvalidJSON verifies that malformed JSON is
// reported as JSON_PARSE_ERROR.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}

// TestDecodeRequest_Parameters verifies the full sampling-parameter
// mapping, including max_completion_tokens -> MaxTokens and
// stop -> StopSequences.
func TestDecodeRequest_Parameters(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"temperature": 0.5,
		"max_completion_tokens": 2048,
		"top_p": 0.9,
		"frequency_penalty": 0.1,
		"presence_penalty": 0.2,
		"stop": ["STOP"]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.5, *req.Parameters.Temperature)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 2048, *req.Parameters.MaxTokens)
	assert.NotNil(t, req.Parameters.TopP)
	assert.Equal(t, 0.9, *req.Parameters.TopP)
	assert.NotNil(t, req.Parameters.FrequencyPenalty)
	assert.Equal(t, 0.1, *req.Parameters.FrequencyPenalty)
	assert.NotNil(t, req.Parameters.PresencePenalty)
	assert.Equal(t, 0.2, *req.Parameters.PresencePenalty)
	assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}

// TestDecodeRequest_ToolChoice table-tests the four tool_choice forms:
// "auto", "none", "required" (-> any), and a named function.
func TestDecodeRequest_ToolChoice(t *testing.T) {
	tests := []struct {
		name     string
		jsonBody string
		want     *canonical.ToolChoice
	}{
		{
			name:     "auto",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
			want:     canonical.NewToolChoiceAuto(),
		},
		{
			name:     "none",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
			want:     canonical.NewToolChoiceNone(),
		},
		{
			name:     "required",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"required"}`,
			want:     canonical.NewToolChoiceAny(),
		},
		{
			name:     "named",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"function","function":{"name":"x"}}}`,
			want:     canonical.NewToolChoiceNamed("x"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := decodeRequest([]byte(tt.jsonBody))
			require.NoError(t, err)
			require.NotNil(t, req.ToolChoice)
			assert.Equal(t, tt.want.Type, req.ToolChoice.Type)
			assert.Equal(t, tt.want.Name, req.ToolChoice.Name)
		})
	}
}
|
||||
|
||||
// TestDecodeRequest_OutputFormat_JSONSchema verifies response_format
// json_schema mapping (name, schema, strict flag).
func TestDecodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {
			"type": "json_schema",
			"json_schema": {
				"name": "my_schema",
				"schema": {"type":"object","properties":{"name":{"type":"string"}}},
				"strict": true
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.Equal(t, "my_schema", req.OutputFormat.Name)
	assert.NotNil(t, req.OutputFormat.Schema)
	require.NotNil(t, req.OutputFormat.Strict)
	assert.True(t, *req.OutputFormat.Strict)
}

// TestDecodeRequest_OutputFormat_JSON verifies response_format
// json_object mapping.
func TestDecodeRequest_OutputFormat_JSON(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {"type": "json_object"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_object", req.OutputFormat.Type)
}

// TestDecodeResponse_StopReasons table-tests the finish_reason ->
// canonical stop-reason mapping.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name         string
		finishReason string
		want         canonical.StopReason
	}{
		{"stop→end_turn", "stop", canonical.StopReasonEndTurn},
		{"length→max_tokens", "length", canonical.StopReasonMaxTokens},
		{"tool_calls→tool_use", "tool_calls", canonical.StopReasonToolUse},
		{"content_filter→content_filter", "content_filter", canonical.StopReasonContentFilter},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "resp-1",
				"model": "gpt-4",
				"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "%s"}],
				"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
			}`, tt.finishReason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}

// TestDecodeResponse_Usage verifies token counts, including
// prompt_tokens_details.cached_tokens -> CacheReadTokens.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
		"usage": {
			"prompt_tokens": 100,
			"completion_tokens": 50,
			"total_tokens": 150,
			"prompt_tokens_details": {"cached_tokens": 80}
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 80, *resp.Usage.CacheReadTokens)
}

// TestDecodeResponse_Refusal verifies that a refusal message (with null
// content) surfaces as a content block carrying the refusal text.
func TestDecodeResponse_Refusal(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": null, "refusal": "我拒绝回答"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Text == "我拒绝回答" {
			found = true
		}
	}
	assert.True(t, found)
}
|
||||
532
backend/internal/conversion/openai/encoder.go
Normal file
532
backend/internal/conversion/openai/encoder.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package openai
|
||||
|
||||
import (
	"encoding/json"
	"strings"
	"time"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
)
|
||||
|
||||
// encodeRequest encodes a canonical request as an OpenAI chat-completions
// request body. The model name always comes from the target provider
// (not from the inbound request), and optional fields are emitted only
// when set so the payload stays minimal.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// System prompt + conversation, flattened into one messages array.
	messages := encodeSystemAndMessages(req)
	result["messages"] = messages

	// Sampling parameters (temperature, top_p, penalties, stop, ...).
	encodeParametersInto(req, result)

	// Tool definitions, each wrapped in OpenAI's {"type":"function"} shell.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tools[i] = map[string]any{
				"type": "function",
				"function": map[string]any{
					"name":        t.Name,
					"description": t.Description,
					"parameters":  t.InputSchema,
				},
			}
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common optional fields.
	if req.UserID != "" {
		result["user"] = req.UserID
	}
	if req.OutputFormat != nil {
		result["response_format"] = encodeOutputFormat(req.OutputFormat)
	}
	if req.ParallelToolUse != nil {
		result["parallel_tool_calls"] = *req.ParallelToolUse
	}
	// Canonical thinking config maps onto OpenAI's reasoning_effort:
	// disabled -> "none"; otherwise the requested effort, defaulting to
	// "medium" when the canonical request names no effort level.
	if req.Thinking != nil {
		switch req.Thinking.Type {
		case "disabled":
			result["reasoning_effort"] = "none"
		default:
			if req.Thinking.Effort != "" {
				result["reasoning_effort"] = req.Thinking.Effort
			} else {
				result["reasoning_effort"] = "medium"
			}
		}
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 请求失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// encodeSystemAndMessages 编码系统消息和消息列表
|
||||
func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
// 系统消息
|
||||
switch v := req.System.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": v,
|
||||
})
|
||||
}
|
||||
case []canonical.SystemBlock:
|
||||
var parts []string
|
||||
for _, b := range v {
|
||||
parts = append(parts, b.Text)
|
||||
}
|
||||
text := joinStrings(parts, "\n\n")
|
||||
if text != "" {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 消息
|
||||
for _, msg := range req.Messages {
|
||||
encoded := encodeMessage(msg)
|
||||
messages = append(messages, encoded...)
|
||||
}
|
||||
|
||||
// 合并连续同角色消息
|
||||
return mergeConsecutiveRoles(messages)
|
||||
}
|
||||
|
||||
// encodeMessage 编码单条消息
|
||||
func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
|
||||
switch msg.Role {
|
||||
case canonical.RoleUser:
|
||||
return []map[string]any{{
|
||||
"role": "user",
|
||||
"content": encodeUserContent(msg.Content),
|
||||
}}
|
||||
case canonical.RoleAssistant:
|
||||
m := map[string]any{"role": "assistant"}
|
||||
var textParts []string
|
||||
var toolUses []canonical.ContentBlock
|
||||
|
||||
for _, b := range msg.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
textParts = append(textParts, b.Text)
|
||||
case "tool_use":
|
||||
toolUses = append(toolUses, b)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolUses) > 0 {
|
||||
if len(textParts) > 0 {
|
||||
m["content"] = joinStrings(textParts, "")
|
||||
} else {
|
||||
m["content"] = nil
|
||||
}
|
||||
tcs := make([]map[string]any, len(toolUses))
|
||||
for i, tu := range toolUses {
|
||||
tcs[i] = map[string]any{
|
||||
"id": tu.ID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": tu.Name,
|
||||
"arguments": string(tu.Input),
|
||||
},
|
||||
}
|
||||
}
|
||||
m["tool_calls"] = tcs
|
||||
} else if len(textParts) > 0 {
|
||||
m["content"] = joinStrings(textParts, "")
|
||||
} else {
|
||||
m["content"] = ""
|
||||
}
|
||||
return []map[string]any{m}
|
||||
|
||||
case canonical.RoleTool:
|
||||
for _, b := range msg.Content {
|
||||
if b.Type == "tool_result" {
|
||||
var contentStr string
|
||||
if b.Content != nil {
|
||||
var s string
|
||||
if json.Unmarshal(b.Content, &s) == nil {
|
||||
contentStr = s
|
||||
} else {
|
||||
contentStr = string(b.Content)
|
||||
}
|
||||
}
|
||||
return []map[string]any{{
|
||||
"role": "tool",
|
||||
"tool_call_id": b.ToolUseID,
|
||||
"content": contentStr,
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeUserContent converts user-message content blocks into the
// OpenAI "content" value: a bare string for the common single-text
// case, or an array of typed parts for multimodal content.
//
// NOTE(review): image/audio/file blocks are emitted with only their
// type tag — no URL/data payload is copied from the canonical block, so
// media content is effectively dropped here. Presumably a placeholder;
// confirm against canonical.ContentBlock's fields and OpenAI's
// image_url / input_audio / file part schemas.
func encodeUserContent(blocks []canonical.ContentBlock) any {
	// Fast path: a single plain-text block encodes as a bare string.
	if len(blocks) == 1 && blocks[0].Type == "text" {
		return blocks[0].Text
	}
	parts := make([]map[string]any, 0, len(blocks))
	for _, b := range blocks {
		switch b.Type {
		case "text":
			parts = append(parts, map[string]any{"type": "text", "text": b.Text})
		case "image":
			parts = append(parts, map[string]any{"type": "image_url"})
		case "audio":
			parts = append(parts, map[string]any{"type": "input_audio"})
		case "file":
			parts = append(parts, map[string]any{"type": "file"})
		}
	}
	// Every block was of an unknown type: fall back to an empty string so
	// the message still carries a valid content field.
	if len(parts) == 0 {
		return ""
	}
	return parts
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return "auto"
|
||||
case "none":
|
||||
return "none"
|
||||
case "any":
|
||||
return "required"
|
||||
case "tool":
|
||||
return map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": choice.Name,
|
||||
},
|
||||
}
|
||||
}
|
||||
return "auto"
|
||||
}
|
||||
|
||||
// encodeParametersInto 编码参数到结果 map
|
||||
func encodeParametersInto(req *canonical.CanonicalRequest, result map[string]any) {
|
||||
if req.Parameters.MaxTokens != nil {
|
||||
result["max_completion_tokens"] = *req.Parameters.MaxTokens
|
||||
}
|
||||
if req.Parameters.Temperature != nil {
|
||||
result["temperature"] = *req.Parameters.Temperature
|
||||
}
|
||||
if req.Parameters.TopP != nil {
|
||||
result["top_p"] = *req.Parameters.TopP
|
||||
}
|
||||
if req.Parameters.FrequencyPenalty != nil {
|
||||
result["frequency_penalty"] = *req.Parameters.FrequencyPenalty
|
||||
}
|
||||
if req.Parameters.PresencePenalty != nil {
|
||||
result["presence_penalty"] = *req.Parameters.PresencePenalty
|
||||
}
|
||||
if len(req.Parameters.StopSequences) > 0 {
|
||||
result["stop"] = req.Parameters.StopSequences
|
||||
}
|
||||
}
|
||||
|
||||
// encodeOutputFormat 编码输出格式
|
||||
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
|
||||
switch format.Type {
|
||||
case "json_object":
|
||||
return map[string]any{"type": "json_object"}
|
||||
case "json_schema":
|
||||
m := map[string]any{"type": "json_schema"}
|
||||
schema := map[string]any{
|
||||
"name": format.Name,
|
||||
}
|
||||
if format.Schema != nil {
|
||||
schema["schema"] = format.Schema
|
||||
}
|
||||
if format.Strict != nil {
|
||||
schema["strict"] = *format.Strict
|
||||
}
|
||||
m["json_schema"] = schema
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse serializes a canonical response as a non-streaming
// OpenAI chat.completion object with a single choice.
//
// Content blocks are partitioned: text blocks are concatenated into
// message.content, thinking blocks into the non-standard
// reasoning_content field, and tool_use blocks become tool_calls (with
// content set to null when only tool calls are present). The "created"
// timestamp is generated at encode time since the canonical model does
// not carry one.
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	var textParts []string
	var thinkingParts []string
	var toolUses []canonical.ContentBlock

	for _, b := range resp.Content {
		switch b.Type {
		case "text":
			textParts = append(textParts, b.Text)
		case "thinking":
			thinkingParts = append(thinkingParts, b.Thinking)
		case "tool_use":
			toolUses = append(toolUses, b)
		}
	}

	message := map[string]any{"role": "assistant"}
	if len(toolUses) > 0 {
		// With tool calls present, OpenAI uses null (not "") for absent text.
		if len(textParts) > 0 {
			message["content"] = joinStrings(textParts, "")
		} else {
			message["content"] = nil
		}
		tcs := make([]map[string]any, len(toolUses))
		for i, tu := range toolUses {
			tcs[i] = map[string]any{
				"id":   tu.ID,
				"type": "function",
				"function": map[string]any{
					"name":      tu.Name,
					"arguments": string(tu.Input),
				},
			}
		}
		message["tool_calls"] = tcs
	} else if len(textParts) > 0 {
		message["content"] = joinStrings(textParts, "")
	} else {
		message["content"] = ""
	}

	if len(thinkingParts) > 0 {
		message["reasoning_content"] = joinStrings(thinkingParts, "")
	}

	var finishReason *string
	if resp.StopReason != nil {
		fr := mapCanonicalToFinishReason(*resp.StopReason)
		finishReason = &fr
	}

	result := map[string]any{
		"id":      resp.ID,
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   resp.Model,
		"choices": []map[string]any{{
			"index":         0,
			"message":       message,
			"finish_reason": finishReason,
		}},
		"usage": encodeUsage(resp.Usage),
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 响应失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// mapCanonicalToFinishReason 映射 Canonical 停止原因到 OpenAI finish_reason
|
||||
func mapCanonicalToFinishReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn:
|
||||
return "stop"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "length"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_calls"
|
||||
case canonical.StopReasonContentFilter:
|
||||
return "content_filter"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeUsage 编码用量
|
||||
func encodeUsage(usage canonical.CanonicalUsage) map[string]any {
|
||||
result := map[string]any{
|
||||
"prompt_tokens": usage.InputTokens,
|
||||
"completion_tokens": usage.OutputTokens,
|
||||
"total_tokens": usage.InputTokens + usage.OutputTokens,
|
||||
}
|
||||
if usage.CacheReadTokens != nil && *usage.CacheReadTokens > 0 {
|
||||
result["prompt_tokens_details"] = map[string]any{
|
||||
"cached_tokens": *usage.CacheReadTokens,
|
||||
}
|
||||
}
|
||||
if usage.ReasoningTokens != nil && *usage.ReasoningTokens > 0 {
|
||||
result["completion_tokens_details"] = map[string]any{
|
||||
"reasoning_tokens": *usage.ReasoningTokens,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
created := int64(0)
|
||||
if m.Created != 0 {
|
||||
created = m.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if m.OwnedBy != "" {
|
||||
ownedBy = m.OwnedBy
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
created := int64(0)
|
||||
if info.Created != 0 {
|
||||
created = info.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if info.OwnedBy != "" {
|
||||
ownedBy = info.OwnedBy
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeEmbeddingRequest 编码嵌入请求
|
||||
func encodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"input": req.Input,
|
||||
}
|
||||
if req.EncodingFormat != "" {
|
||||
result["encoding_format"] = req.EncodingFormat
|
||||
}
|
||||
if req.Dimensions != nil {
|
||||
result["dimensions"] = *req.Dimensions
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeEmbeddingResponse 编码嵌入响应
|
||||
func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||||
data := make([]map[string]any, len(resp.Data))
|
||||
for i, d := range resp.Data {
|
||||
data[i] = map[string]any{
|
||||
"index": d.Index,
|
||||
"embedding": d.Embedding,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeRerankRequest 编码重排序请求
|
||||
func encodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"query": req.Query,
|
||||
"documents": req.Documents,
|
||||
}
|
||||
if req.TopN != nil {
|
||||
result["top_n"] = *req.TopN
|
||||
}
|
||||
if req.ReturnDocuments != nil {
|
||||
result["return_documents"] = *req.ReturnDocuments
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeRerankResponse 编码重排序响应
|
||||
func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
results := make([]map[string]any, len(resp.Results))
|
||||
for i, r := range resp.Results {
|
||||
m := map[string]any{
|
||||
"index": r.Index,
|
||||
"relevance_score": r.RelevanceScore,
|
||||
}
|
||||
if r.Document != nil {
|
||||
m["document"] = *r.Document
|
||||
}
|
||||
results[i] = m
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"results": results,
|
||||
"model": resp.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// joinStrings concatenates parts with sep between consecutive elements.
// Delegates to strings.Join, which pre-sizes a single buffer, replacing
// the previous quadratic `+=` loop.
func joinStrings(parts []string, sep string) string {
	return strings.Join(parts, sep)
}
|
||||
|
||||
// mergeConsecutiveRoles merges runs of consecutive messages that share
// a role by concatenating their content (string+string, or
// part-list+part-list).
//
// Fixes over the previous implementation:
//   - when two same-role messages carried incompatible content shapes
//     (e.g. a string followed by a part list), the second message was
//     dropped entirely, silently losing content — such pairs are now
//     kept as separate messages;
//   - "tool" messages and messages carrying tool_calls are never
//     merged, since merging them would discard tool_call_id / tool_calls
//     correlation data.
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
	if len(messages) <= 1 {
		return messages
	}
	var result []map[string]any
	for _, msg := range messages {
		if len(result) > 0 {
			last := result[len(result)-1]
			role := msg["role"]
			if last["role"] == role && role != "tool" {
				_, lastHasTC := last["tool_calls"]
				_, curHasTC := msg["tool_calls"]
				if !lastHasTC && !curHasTC && mergeContent(last, msg) {
					continue
				}
			}
		}
		result = append(result, msg)
	}
	return result
}

// mergeContent appends msg's content onto last's when both contents
// have the same mergeable shape; it reports whether the merge happened.
func mergeContent(last, msg map[string]any) bool {
	switch lv := last["content"].(type) {
	case string:
		if cv, ok := msg["content"].(string); ok {
			last["content"] = lv + cv
			return true
		}
	case []any:
		if cv, ok := msg["content"].([]any); ok {
			last["content"] = append(lv, cv...)
			return true
		}
	}
	return false
}
|
||||
355
backend/internal/conversion/openai/encoder_test.go
Normal file
355
backend/internal/conversion/openai/encoder_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEncodeRequest_Basic verifies model override (provider model wins),
// the stream flag, and message count for a minimal encode.
func TestEncodeRequest_Basic(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Stream:   true,
	}
	provider := conversion.NewTargetProvider("", "key", "my-model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "my-model", result["model"])
	assert.Equal(t, true, result["stream"])

	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assert.Len(t, msgs, 1)
}

// TestEncodeRequest_SystemInjection verifies that the canonical System
// string becomes a leading "system" message in the encoded payload.
func TestEncodeRequest_SystemInjection(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		System:   "你是助手",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs := result["messages"].([]any)
	assert.Len(t, msgs, 2)
	firstMsg := msgs[0].(map[string]any)
	assert.Equal(t, "system", firstMsg["role"])
	assert.Equal(t, "你是助手", firstMsg["content"])
}
|
||||
|
||||
func TestEncodeRequest_ToolCalls(t *testing.T) {
|
||||
input := json.RawMessage(`{"city":"北京"}`)
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []canonical.CanonicalMessage{
|
||||
{
|
||||
Role: canonical.RoleAssistant,
|
||||
Content: []canonical.ContentBlock{
|
||||
canonical.NewToolUseBlock("call_1", "get_weather", input),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assistantMsg := msgs[0].(map[string]any)
|
||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, toolCalls, 1)
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
assert.Equal(t, "call_1", tc["id"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_Thinking(t *testing.T) {
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "high", result["reasoning_effort"])
|
||||
}
|
||||
|
||||
// TestEncodeResponse_Basic verifies that a canonical text response is encoded
// as an OpenAI chat.completion with the correct content and finish_reason.
func TestEncodeResponse_Basic(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "resp-1",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "resp-1", result["id"])
	assert.Equal(t, "chat.completion", result["object"])

	choices := result["choices"].([]any)
	choice := choices[0].(map[string]any)
	msg := choice["message"].(map[string]any)
	assert.Equal(t, "你好", msg["content"])
	// end_turn maps to OpenAI's "stop".
	assert.Equal(t, "stop", choice["finish_reason"])
}

// TestEncodeResponse_ToolUse verifies tool_use content blocks become
// tool_calls entries on the encoded assistant message.
func TestEncodeResponse_ToolUse(t *testing.T) {
	sr := canonical.StopReasonToolUse
	input := json.RawMessage(`{"q":"test"}`)
	resp := &canonical.CanonicalResponse{
		ID:         "resp-2",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices := result["choices"].([]any)
	msg := choices[0].(map[string]any)["message"].(map[string]any)
	tcs, ok := msg["tool_calls"].([]any)
	require.True(t, ok)
	assert.Len(t, tcs, 1)
}

// TestEncodeModelsResponse verifies the canonical model list is encoded as an
// OpenAI-style {"object":"list","data":[...]} payload.
func TestEncodeModelsResponse(t *testing.T) {
	list := &canonical.CanonicalModelList{
		Models: []canonical.CanonicalModel{
			{ID: "gpt-4", Created: 1700000000, OwnedBy: "openai"},
			{ID: "gpt-3.5-turbo", Created: 1700000001, OwnedBy: "openai"},
		},
	}

	body, err := encodeModelsResponse(list)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "list", result["object"])
	data := result["data"].([]any)
	assert.Len(t, data, 2)
}
|
||||
|
||||
// TestMergeConsecutiveRoles verifies two runs of same-role string messages are
// each collapsed into one message with concatenated content.
func TestMergeConsecutiveRoles(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "A"},
		{"role": "user", "content": "B"},
		{"role": "assistant", "content": "C"},
		{"role": "assistant", "content": "D"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 2)
	assert.Equal(t, "AB", result[0]["content"])
	assert.Equal(t, "CD", result[1]["content"])
}

// TestMergeConsecutiveRoles_NotOverwriting verifies merging concatenates the
// contents rather than replacing the first with the second.
func TestMergeConsecutiveRoles_NotOverwriting(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "你好"},
		{"role": "user", "content": "世界"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 1)
	assert.Equal(t, "你好世界", result[0]["content"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_Auto verifies the canonical "auto" tool choice
// is encoded as the plain string "auto".
func TestEncodeRequest_ToolChoice_Auto(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAuto(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "auto", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_None verifies the canonical "none" tool choice
// is encoded as the plain string "none".
func TestEncodeRequest_ToolChoice_None(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNone(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "none", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_Required verifies the canonical "any" tool
// choice maps to OpenAI's "required".
func TestEncodeRequest_ToolChoice_Required(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAny(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "required", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_Named verifies a named tool choice becomes the
// {"type":"function","function":{"name":...}} object form.
func TestEncodeRequest_ToolChoice_Named(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNamed("my_func"),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	tc, ok := result["tool_choice"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "function", tc["type"])
	fn, ok := tc["function"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_func", fn["name"])
}
|
||||
|
||||
// TestEncodeRequest_OutputFormat_JSONSchema verifies a json_schema output
// format is encoded as OpenAI's response_format.json_schema structure.
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type:   "json_schema",
			Name:   "my_schema",
			Schema: schema,
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	rf, ok := result["response_format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", rf["type"])
	js, ok := rf["json_schema"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_schema", js["name"])
	assert.NotNil(t, js["schema"])
}

// TestEncodeRequest_OutputFormat_Text verifies that no response_format field
// is emitted when the request has no OutputFormat.
func TestEncodeRequest_OutputFormat_Text(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	_, hasResponseFormat := result["response_format"]
	assert.False(t, hasResponseFormat)
}

// TestEncodeResponse_Thinking verifies a thinking block is surfaced on the
// message as the non-standard reasoning_content field, separate from content.
func TestEncodeResponse_Thinking(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:    "resp-thinking",
		Model: "gpt-4",
		Content: []canonical.ContentBlock{
			canonical.NewTextBlock("回答"),
			canonical.NewThinkingBlock("思考过程"),
		},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices := result["choices"].([]any)
	msg := choices[0].(map[string]any)["message"].(map[string]any)
	assert.Equal(t, "回答", msg["content"])
	assert.Equal(t, "思考过程", msg["reasoning_content"])
}

// TestEncodeRequest_Parameters verifies sampling parameters and stop sequences
// map to their OpenAI field names (MaxTokens -> max_completion_tokens).
func TestEncodeRequest_Parameters(t *testing.T) {
	temp := 0.5
	maxTokens := 2048
	topP := 0.9
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Parameters: canonical.RequestParameters{
			Temperature:   &temp,
			MaxTokens:     &maxTokens,
			TopP:          &topP,
			StopSequences: []string{"STOP", "END"},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, temp, result["temperature"])
	// JSON numbers unmarshal into float64.
	assert.Equal(t, float64(maxTokens), result["max_completion_tokens"])
	assert.Equal(t, topP, result["top_p"])
	stop, ok := result["stop"].([]any)
	require.True(t, ok)
	assert.Len(t, stop, 2)
	assert.Equal(t, "STOP", stop[0])
	assert.Equal(t, "END", stop[1])
}
|
||||
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder incrementally decodes OpenAI SSE chat-completion chunks into
// canonical stream events, tracking open content blocks so consumers receive
// well-nested start/delta/stop sequences.
type StreamDecoder struct {
	messageStarted     bool                      // a MessageStart event has already been emitted
	openBlocks         map[int]string            // block index -> type tag ("text", "thinking", "tool_use_<n>")
	textBlockIndex     int                       // index of the open text block; -1 when none
	thinkingBlockIndex int                       // index of the open thinking block; -1 when none
	refusalBlockIndex  int                       // index of the open refusal (text) block; -1 when none
	toolCallIDMap      map[int]string            // tool_call index -> call ID
	toolCallNameMap    map[int]string            // tool_call index -> function name
	nextToolCallIdx    int                       // NOTE(review): not referenced in the visible code — confirm it is used elsewhere
	utf8Remainder      []byte                    // trailing bytes of a multi-byte rune split across chunks
	accumulatedUsage   *canonical.CanonicalUsage // usage from the final usage-only chunk, if any
}
|
||||
|
||||
// NewStreamDecoder 创建 OpenAI 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
openBlocks: make(map[int]string),
|
||||
toolCallIDMap: make(map[int]string),
|
||||
toolCallNameMap: make(map[int]string),
|
||||
textBlockIndex: -1,
|
||||
thinkingBlockIndex: -1,
|
||||
refusalBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk parses one raw SSE chunk into canonical stream events.
// It splits the chunk into lines, decodes each "data: " payload, and on the
// "[DONE]" sentinel closes all open content blocks and returns immediately
// (any lines after [DONE] in the same chunk are not processed).
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	// Prepend the UTF-8 remainder saved from a previous chunk.
	// NOTE(review): the remainder is saved from a JSON payload's tail in
	// processDataChunk but re-attached here to the *raw SSE bytes* (which
	// start with "data: ") — confirm a rune split across chunks is actually
	// reassembled correctly, rather than corrupting line detection.
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		data = append(d.utf8Remainder, rawChunk...)
		d.utf8Remainder = nil
	}

	var events []canonical.CanonicalStreamEvent

	// Parse SSE "data:" lines; anything else (event:, comments, blanks) is skipped.
	lines := strings.Split(string(data), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")

		if payload == "[DONE]" {
			// End of stream: emit stop events for every still-open block.
			events = append(events, d.flushOpenBlocks()...)
			return events
		}

		chunkEvents := d.processDataChunk([]byte(payload))
		events = append(events, chunkEvents...)
	}

	return events
}
|
||||
|
||||
// Flush 刷新解码器状态
|
||||
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDataChunk decodes one JSON "data" payload into canonical events.
// It emits MessageStart on the first chunk, opens/extends content blocks for
// text, reasoning_content, refusal and tool_calls deltas, and translates
// finish_reason and trailing usage-only chunks into MessageDelta/MessageStop.
// Malformed JSON is silently skipped (returns no events).
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
	// Trim an incomplete trailing UTF-8 rune and stash it for the next chunk.
	// NOTE(review): if the payload really is truncated mid-rune, the remaining
	// JSON is also truncated and json.Unmarshal below will fail — confirm the
	// remainder path is reachable in practice.
	if !utf8.Valid(data) {
		validEnd := len(data)
		for !utf8.Valid(data[:validEnd]) {
			validEnd--
		}
		d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
		data = data[:validEnd]
	}

	var chunk StreamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		return nil
	}

	var events []canonical.CanonicalStreamEvent

	// First chunk: emit MessageStart with the stream's id and model.
	if !d.messageStarted {
		events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
		d.messageStarted = true
	}

	for _, choice := range chunk.Choices {
		if choice.Delta == nil {
			continue
		}
		delta := choice.Delta

		// Plain text content. Non-string content is stringified via fmt.
		if delta.Content != nil {
			text := ""
			switch v := delta.Content.(type) {
			case string:
				text = v
			default:
				text = fmt.Sprintf("%v", v)
			}
			if text != "" {
				// Lazily open the text block on the first non-empty delta.
				if _, ok := d.openBlocks[d.textBlockIndex]; !ok || d.textBlockIndex < 0 {
					d.textBlockIndex = d.allocateBlockIndex()
					d.openBlocks[d.textBlockIndex] = "text"
					events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex,
						canonical.StreamContentBlock{Type: "text", Text: ""}))
				}
				events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
			}
		}

		// reasoning_content (non-standard OpenAI extension) -> thinking block.
		if delta.ReasoningContent != "" {
			if _, ok := d.openBlocks[d.thinkingBlockIndex]; !ok || d.thinkingBlockIndex < 0 {
				d.thinkingBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.thinkingBlockIndex] = "thinking"
				events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex,
					canonical.StreamContentBlock{Type: "thinking", Thinking: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
		}

		// refusal deltas are surfaced as a dedicated text block.
		if delta.Refusal != "" {
			if _, ok := d.openBlocks[d.refusalBlockIndex]; !ok || d.refusalBlockIndex < 0 {
				d.refusalBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.refusalBlockIndex] = "text"
				events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex,
					canonical.StreamContentBlock{Type: "text", Text: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
		}

		// tool_calls: a delta with an ID opens a new tool_use block; argument
		// fragments are forwarded as input_json deltas on the matching block.
		if len(delta.ToolCalls) > 0 {
			for _, tc := range delta.ToolCalls {
				tcIdx := 0
				if tc.Index != nil {
					tcIdx = *tc.Index
				}

				if tc.ID != "" {
					// New tool call block: remember its ID and function name.
					d.toolCallIDMap[tcIdx] = tc.ID
					if tc.Function != nil {
						d.toolCallNameMap[tcIdx] = tc.Function.Name
					}
					blockIdx := d.allocateBlockIndex()
					d.openBlocks[blockIdx] = fmt.Sprintf("tool_use_%d", tcIdx)
					name := d.toolCallNameMap[tcIdx]
					events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: name}))
				}

				// Route argument fragments to this tool call's block.
				blockIdx := d.findToolUseBlockIndex(tcIdx)
				if tc.Function != nil && tc.Function.Arguments != "" {
					events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
				}
			}
		}

		// finish_reason: close all open blocks, then emit the mapped stop
		// reason and a MessageStop event.
		if choice.FinishReason != nil && *choice.FinishReason != "" {
			events = append(events, d.flushOpenBlocks()...)
			sr := mapFinishReason(*choice.FinishReason)
			events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
			events = append(events, canonical.NewMessageStopEvent())
		}
	}

	// Trailing usage-only chunk (choices empty): record the usage and emit a
	// MessageDelta carrying it.
	if len(chunk.Choices) == 0 && chunk.Usage != nil {
		usage := decodeUsage(chunk.Usage)
		d.accumulatedUsage = &usage
		events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
	}

	return events
}
|
||||
|
||||
// allocateBlockIndex 分配 block 索引
|
||||
func (d *StreamDecoder) allocateBlockIndex() int {
|
||||
maxIdx := -1
|
||||
for k := range d.openBlocks {
|
||||
if k > maxIdx {
|
||||
maxIdx = k
|
||||
}
|
||||
}
|
||||
return maxIdx + 1
|
||||
}
|
||||
|
||||
// findToolUseBlockIndex 查找 tool use block 索引
|
||||
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
|
||||
key := fmt.Sprintf("tool_use_%d", tcIdx)
|
||||
for blockIdx, typ := range d.openBlocks {
|
||||
if typ == key {
|
||||
return blockIdx
|
||||
}
|
||||
}
|
||||
return d.allocateBlockIndex()
|
||||
}
|
||||
|
||||
// flushOpenBlocks 关闭所有 open blocks
|
||||
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
|
||||
var events []canonical.CanonicalStreamEvent
|
||||
for idx := range d.openBlocks {
|
||||
events = append(events, canonical.NewContentBlockStopEvent(idx))
|
||||
}
|
||||
d.openBlocks = make(map[int]string)
|
||||
return events
|
||||
}
|
||||
355
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
355
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeSSEData wraps payload in a single SSE "data:" frame terminated by a
// blank line.
func makeSSEData(payload string) []byte {
	var frame []byte
	frame = append(frame, "data: "...)
	frame = append(frame, payload...)
	frame = append(frame, "\n\n"...)
	return frame
}
|
||||
|
||||
// TestStreamDecoder_BasicText verifies a single text delta chunk produces a
// MessageStart event (with the chunk id) and a text_delta content event.
func TestStreamDecoder_BasicText(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":     "chatcmpl-1",
		"object": "chat.completion.chunk",
		"model":  "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{"content": "你好"},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)

	foundStart := false
	foundDelta := false
	for _, e := range events {
		if e.Type == canonical.EventMessageStart {
			foundStart = true
			assert.Equal(t, "chatcmpl-1", e.Message.ID)
		}
		if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
			foundDelta = true
			assert.Equal(t, "text_delta", e.Delta.Type)
			assert.Equal(t, "你好", e.Delta.Text)
		}
	}
	assert.True(t, foundStart)
	assert.True(t, foundDelta)
}

// TestStreamDecoder_ToolCalls verifies a tool_calls delta opens a tool_use
// content block carrying the call ID and function name.
func TestStreamDecoder_ToolCalls(t *testing.T) {
	d := NewStreamDecoder()

	idx := 0
	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"tool_calls": []any{
						map[string]any{
							"index": &idx,
							"id":    "call_1",
							"type":  "function",
							"function": map[string]any{
								"name":      "get_weather",
								"arguments": "{\"city\":\"北京\"}",
							},
						},
					},
				},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)

	found := false
	for _, e := range events {
		if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_1", e.ContentBlock.ID)
			assert.Equal(t, "get_weather", e.ContentBlock.Name)
		}
	}
	assert.True(t, found)
}

// TestStreamDecoder_Thinking verifies a reasoning_content delta is decoded as
// a thinking_delta content event.
func TestStreamDecoder_Thinking(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"reasoning_content": "思考中",
				},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	found := false
	for _, e := range events {
		if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "thinking_delta" {
			found = true
			assert.Equal(t, "思考中", e.Delta.Thinking)
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
// TestStreamDecoder_FinishReason verifies finish_reason "stop" produces a
// MessageDelta with StopReasonEndTurn followed by a MessageStop event.
func TestStreamDecoder_FinishReason(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": "stop",
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	foundStop := false
	foundMsgStop := false
	for _, e := range events {
		if e.Type == canonical.EventMessageDelta && e.StopReason != nil {
			foundStop = true
			assert.Equal(t, canonical.StopReasonEndTurn, *e.StopReason)
		}
		if e.Type == canonical.EventMessageStop {
			foundMsgStop = true
		}
	}
	assert.True(t, foundStop)
	assert.True(t, foundMsgStop)
}

// TestStreamDecoder_DoneSignal verifies the [DONE] sentinel closes the open
// content block opened by a preceding text chunk.
func TestStreamDecoder_DoneSignal(t *testing.T) {
	d := NewStreamDecoder()

	// Send one text chunk first so there is an open block to close.
	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{"content": "hi"},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := append(makeSSEData(string(data)), []byte("data: [DONE]\n\n")...)

	events := d.ProcessChunk(raw)
	// A block stop event must appear ([DONE] triggers flushOpenBlocks).
	foundBlockStop := false
	for _, e := range events {
		if e.Type == canonical.EventContentBlockStop {
			foundBlockStop = true
		}
	}
	assert.True(t, foundBlockStop)
}
|
||||
|
||||
func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"refusal": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
events := d.ProcessChunk(raw)
|
||||
_ = events
|
||||
}
|
||||
|
||||
// 检查只创建了一个 text block(refusal 复用同一个 block)
|
||||
assert.Contains(t, d.openBlocks, d.refusalBlockIndex)
|
||||
}
|
||||
|
||||
// makeChunkSSE marshals chunk to JSON and wraps it in an SSE "data:" frame.
func makeChunkSSE(chunk map[string]any) []byte {
	payload, _ := json.Marshal(chunk)
	frame := append([]byte("data: "), payload...)
	return append(frame, "\n\n"...)
}
|
||||
|
||||
// TestStreamDecoder_UsageChunk verifies a usage-only chunk (empty choices)
// produces a MessageDelta event carrying the decoded token counts.
func TestStreamDecoder_UsageChunk(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":      "chatcmpl-usage",
		"object":  "chat.completion.chunk",
		"model":   "gpt-4",
		"choices": []any{},
		"usage": map[string]any{
			"prompt_tokens":     100,
			"completion_tokens": 50,
			"total_tokens":      150,
		},
	}
	raw := makeChunkSSE(chunk)

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)

	found := false
	for _, e := range events {
		if e.Type == canonical.EventMessageDelta {
			found = true
			require.NotNil(t, e.Usage)
			assert.Equal(t, 100, e.Usage.InputTokens)
			assert.Equal(t, 50, e.Usage.OutputTokens)
		}
	}
	assert.True(t, found)
}

// TestStreamDecoder_MultipleToolCalls verifies two tool calls arriving in
// separate chunks are assigned distinct content-block indices.
func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
	d := NewStreamDecoder()

	idx0 := 0
	chunk1 := map[string]any{
		"id":    "chatcmpl-mt",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"tool_calls": []any{
						map[string]any{
							"index": &idx0,
							"id":    "call_a",
							"type":  "function",
							"function": map[string]any{
								"name":      "func_a",
								"arguments": "{}",
							},
						},
					},
				},
			},
		},
	}

	idx1 := 1
	chunk2 := map[string]any{
		"id":    "chatcmpl-mt",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"tool_calls": []any{
						map[string]any{
							"index": &idx1,
							"id":    "call_b",
							"type":  "function",
							"function": map[string]any{
								"name":      "func_b",
								"arguments": "{}",
							},
						},
					},
				},
			},
		},
	}

	events1 := d.ProcessChunk(makeChunkSSE(chunk1))
	require.NotEmpty(t, events1)

	events2 := d.ProcessChunk(makeChunkSSE(chunk2))
	require.NotEmpty(t, events2)

	// Collect the block indices of all tool_use start events.
	blockIndices := map[int]bool{}
	for _, e := range append(events1, events2...) {
		if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
			require.NotNil(t, e.Index)
			blockIndices[*e.Index] = true
		}
	}
	assert.Equal(t, 2, len(blockIndices), "两个 tool call 应分配不同的 block 索引")
}

// TestStreamDecoder_Flush verifies Flush on a fresh decoder yields no events.
func TestStreamDecoder_Flush(t *testing.T) {
	d := NewStreamDecoder()
	result := d.Flush()
	assert.Nil(t, result)
}

// TestStreamDecoder_MultipleChunks_Text verifies two text chunks in one SSE
// payload yield two text deltas in arrival order.
func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
	d := NewStreamDecoder()

	chunk1 := map[string]any{
		"id":    "chatcmpl-multi",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{"content": "你好"},
			},
		},
	}
	chunk2 := map[string]any{
		"id":    "chatcmpl-multi",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{"content": "世界"},
			},
		},
	}

	raw := append(makeChunkSSE(chunk1), makeChunkSSE(chunk2)...)
	events := d.ProcessChunk(raw)

	deltas := []string{}
	for _, e := range events {
		if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "text_delta" {
			deltas = append(deltas, e.Delta.Text)
		}
	}
	assert.Equal(t, []string{"你好", "世界"}, deltas)
}
|
||||
217
backend/internal/conversion/openai/stream_encoder.go
Normal file
217
backend/internal/conversion/openai/stream_encoder.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder converts canonical stream events back into OpenAI SSE
// chat-completion chunks.
type StreamEncoder struct {
	bufferedStart     *canonical.CanonicalStreamEvent // pending content_block_start, consumed by the first delta
	toolCallIndexMap  map[string]int                  // tool_use block ID -> OpenAI tool_calls array index
	nextToolCallIndex int                             // next index to assign to a new tool_use block
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
func NewStreamEncoder() *StreamEncoder {
|
||||
return &StreamEncoder{
|
||||
toolCallIndexMap: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeEvent 编码 Canonical 事件为 SSE chunk
|
||||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
switch event.Type {
|
||||
case canonical.EventMessageStart:
|
||||
return e.encodeMessageStart(event)
|
||||
case canonical.EventContentBlockStart:
|
||||
return e.bufferBlockStart(event)
|
||||
case canonical.EventContentBlockDelta:
|
||||
return e.encodeContentBlockDelta(event)
|
||||
case canonical.EventContentBlockStop:
|
||||
return nil
|
||||
case canonical.EventMessageDelta:
|
||||
return e.encodeMessageDelta(event)
|
||||
case canonical.EventMessageStop:
|
||||
return [][]byte{[]byte("data: [DONE]\n\n")}
|
||||
case canonical.EventPing, canonical.EventError:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区
|
||||
func (e *StreamEncoder) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
id := ""
|
||||
model := ""
|
||||
if event.Message != nil {
|
||||
id = event.Message.ID
|
||||
model = event.Message.Model
|
||||
}
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"role": "assistant"},
|
||||
}},
|
||||
}
|
||||
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// bufferBlockStart 缓冲 block start 事件
|
||||
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
e.bufferedStart = &event
|
||||
if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" {
|
||||
idx := e.nextToolCallIndex
|
||||
e.nextToolCallIndex++
|
||||
if event.ContentBlock.ID != "" {
|
||||
e.toolCallIndexMap[event.ContentBlock.ID] = idx
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
return e.encodeTextDelta(event)
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
return e.encodeInputJSONDelta(event)
|
||||
case canonical.DeltaTypeThinking:
|
||||
return e.encodeThinkingDelta(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTextDelta 编码文本增量
|
||||
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"content": event.Delta.Text,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeInputJSONDelta 编码 JSON 输入增量
|
||||
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil {
|
||||
// 首次 delta,含 id 和 name
|
||||
start := e.bufferedStart.ContentBlock
|
||||
tcIdx := 0
|
||||
if start.ID != "" {
|
||||
tcIdx = e.toolCallIndexMap[start.ID]
|
||||
}
|
||||
delta := map[string]any{
|
||||
"tool_calls": []map[string]any{{
|
||||
"index": tcIdx,
|
||||
"id": start.ID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": start.Name,
|
||||
"arguments": event.Delta.PartialJSON,
|
||||
},
|
||||
}},
|
||||
}
|
||||
e.bufferedStart = nil
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// 后续 delta,仅含 arguments
|
||||
// 通过 index 查找 tool call
|
||||
tcIdx := 0
|
||||
if event.Index != nil {
|
||||
for id, idx := range e.toolCallIndexMap {
|
||||
if idx == tcIdx {
|
||||
_ = id
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
delta := map[string]any{
|
||||
"tool_calls": []map[string]any{{
|
||||
"index": tcIdx,
|
||||
"function": map[string]any{
|
||||
"arguments": event.Delta.PartialJSON,
|
||||
},
|
||||
}},
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeThinkingDelta 编码思考增量
|
||||
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"reasoning_content": event.Delta.Thinking,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
var chunks [][]byte
|
||||
|
||||
if event.StopReason != nil {
|
||||
fr := mapCanonicalToFinishReason(*event.StopReason)
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": fr,
|
||||
}},
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
if event.Usage != nil {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{},
|
||||
"usage": encodeUsage(*event.Usage),
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// encodeDelta 编码 delta 到 SSE chunk
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// marshalChunk 序列化 chunk 为 SSE data
|
||||
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))}
|
||||
}
|
||||
172
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
172
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEvent("chatcmpl-1", "gpt-4")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "data: "))
|
||||
assert.Contains(t, s, "chatcmpl-1")
|
||||
assert.Contains(t, s, "chat.completion.chunk")
|
||||
|
||||
var payload map[string]any
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_TextDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "你好")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStopEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "data: [DONE]\n\n", string(chunks[0]))
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Buffering(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
// ContentBlockStart 应被缓冲,不输出
|
||||
startEvent := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})
|
||||
chunks := e.EncodeEvent(startEvent)
|
||||
assert.Nil(t, chunks)
|
||||
assert.NotNil(t, e.bufferedStart)
|
||||
|
||||
// 第一个 delta 触发输出(清空缓冲)
|
||||
deltaEvent := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "hello"})
|
||||
chunks = e.EncodeEvent(deltaEvent)
|
||||
require.NotEmpty(t, chunks)
|
||||
assert.Nil(t, e.bufferedStart)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStop_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
idx := 0
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Ping_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewPingEvent()
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Error_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewErrorEvent("test_error", "测试错误")
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
chunks := e.Flush()
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ThinkingDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeThinking),
|
||||
Thinking: "思考内容",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "reasoning_content")
|
||||
assert.Contains(t, s, "思考内容")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_InputJSONDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
}))
|
||||
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: "{\"city\":\"北京\"}",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "tool_calls")
|
||||
assert.Contains(t, s, "北京")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
sr := canonical.StopReasonEndTurn
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "finish_reason")
|
||||
assert.Contains(t, s, "stop")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
usage := canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
Usage: &usage,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "usage")
|
||||
assert.Contains(t, s, "prompt_tokens")
|
||||
}
|
||||
245
backend/internal/conversion/openai/types.go
Normal file
245
backend/internal/conversion/openai/types.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package openai
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest is an OpenAI Chat Completions request body.
type ChatCompletionRequest struct {
	Model               string          `json:"model"`
	Messages            []Message       `json:"messages"`
	Tools               []Tool          `json:"tools,omitempty"`
	ToolChoice          any             `json:"tool_choice,omitempty"` // string mode or object; kept untyped
	MaxTokens           *int            `json:"max_tokens,omitempty"`
	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
	Temperature         *float64        `json:"temperature,omitempty"`
	TopP                *float64        `json:"top_p,omitempty"`
	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
	Stop                any             `json:"stop,omitempty"` // string or array; kept untyped
	Stream              bool            `json:"stream,omitempty"`
	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
	User                string          `json:"user,omitempty"`
	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
	ParallelToolCalls   *bool           `json:"parallel_tool_calls,omitempty"`
	ReasoningEffort     string          `json:"reasoning_effort,omitempty"`
	N                   *int            `json:"n,omitempty"`
	Seed                *int            `json:"seed,omitempty"`
	Logprobs            *bool           `json:"logprobs,omitempty"`
	TopLogprobs         *int            `json:"top_logprobs,omitempty"`

	// Deprecated legacy function-calling fields, kept for compatibility.
	Functions    []FunctionDef `json:"functions,omitempty"`
	FunctionCall any           `json:"function_call,omitempty"`
}
|
||||
|
||||
// Message is an OpenAI chat message. Content is typed `any` because the
// API accepts both a plain string and structured content parts.
type Message struct {
	Role             string     `json:"role"`
	Content          any        `json:"content"`
	Name             string     `json:"name,omitempty"`
	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID       string     `json:"tool_call_id,omitempty"`
	Refusal          string     `json:"refusal,omitempty"`
	ReasoningContent string     `json:"reasoning_content,omitempty"`

	// Deprecated: legacy message-level function call.
	FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
}
|
||||
|
||||
// ToolCall is a tool invocation emitted by the model.
type ToolCall struct {
	Index    *int          `json:"index,omitempty"` // position in the tool_calls array (used in streaming)
	ID       string        `json:"id,omitempty"`
	Type     string        `json:"type,omitempty"`
	Function *FunctionCall `json:"function,omitempty"`
	Custom   *CustomTool   `json:"custom,omitempty"`
}

// FunctionCall carries a called function's name and its raw JSON arguments
// string.
type FunctionCall struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// CustomTool is a custom (non-function) tool invocation payload.
type CustomTool struct {
	Name  string `json:"name"`
	Input string `json:"input"`
}

// FunctionCallMsg is the deprecated message-level function_call payload.
type FunctionCallMsg struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
|
||||
|
||||
// Tool is an OpenAI tool definition.
type Tool struct {
	Type     string       `json:"type"`
	Function *FunctionDef `json:"function,omitempty"`
}

// FunctionDef describes a callable function exposed to the model.
type FunctionDef struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	Parameters  json.RawMessage `json:"parameters,omitempty"` // JSON Schema, deliberately kept raw
	Strict      *bool           `json:"strict,omitempty"`
}
|
||||
|
||||
// ResponseFormat selects the response format (e.g. text, json_object,
// json_schema).
type ResponseFormat struct {
	Type       string         `json:"type"`
	JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
}

// JSONSchemaDef is a named JSON Schema used for structured output.
type JSONSchemaDef struct {
	Name   string          `json:"name"`
	Schema json.RawMessage `json:"schema,omitempty"` // kept raw; forwarded verbatim
	Strict *bool           `json:"strict,omitempty"`
}

// StreamOptions holds streaming-only request options.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}
|
||||
|
||||
// ChatCompletionResponse is a non-streaming Chat Completions response.
type ChatCompletionResponse struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             *Usage   `json:"usage,omitempty"`
	SystemFingerprint string   `json:"system_fingerprint,omitempty"`
	ServiceTier       string   `json:"service_tier,omitempty"`
}

// Choice is one completion choice. Message is populated in non-streaming
// responses; Delta is populated in streaming chunks.
type Choice struct {
	Index        int      `json:"index"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Message `json:"delta,omitempty"`
	FinishReason *string  `json:"finish_reason"` // no omitempty: null must be emitted while streaming
	Logprobs     any      `json:"logprobs,omitempty"`
}

// Usage reports token consumption for a completion.
type Usage struct {
	PromptTokens            int                      `json:"prompt_tokens"`
	CompletionTokens        int                      `json:"completion_tokens"`
	TotalTokens             int                      `json:"total_tokens"`
	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details,omitempty"`
	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
}

// PromptTokensDetails breaks down prompt-side token usage.
type PromptTokensDetails struct {
	CachedTokens int `json:"cached_tokens,omitempty"`
	AudioTokens  int `json:"audio_tokens,omitempty"`
}

// CompletionTokensDetails breaks down completion-side token usage.
type CompletionTokensDetails struct {
	ReasoningTokens          int `json:"reasoning_tokens,omitempty"`
	AudioTokens              int `json:"audio_tokens,omitempty"`
	AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
	RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
}
|
||||
|
||||
// StreamChunk is a streaming "chat.completion.chunk" payload.
type StreamChunk struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             *Usage   `json:"usage,omitempty"`
	SystemFingerprint string   `json:"system_fingerprint,omitempty"`
}
|
||||
|
||||
// ModelsResponse is the GET /v1/models list response.
type ModelsResponse struct {
	Object string      `json:"object"`
	Data   []ModelItem `json:"data"`
}

// ModelItem is one entry in the model list.
type ModelItem struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ModelInfoResponse is the GET /v1/models/{id} single-model response.
type ModelInfoResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
||||
|
||||
// EmbeddingRequest is the /v1/embeddings request body.
type EmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or array; kept untyped
	EncodingFormat string `json:"encoding_format,omitempty"`
	Dimensions     *int   `json:"dimensions,omitempty"`
}

// EmbeddingResponse is the /v1/embeddings response body.
type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingData `json:"data"`
	Model  string          `json:"model"`
	Usage  EmbeddingUsage  `json:"usage"`
}

// EmbeddingData is one embedding result. Embedding is typed `any` because
// the wire format depends on encoding_format (float array vs. base64).
type EmbeddingData struct {
	Index     int `json:"index"`
	Embedding any `json:"embedding"`
}

// EmbeddingUsage reports embedding token usage.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
||||
|
||||
// RerankRequest is a rerank request body.
type RerankRequest struct {
	Model           string   `json:"model"`
	Query           string   `json:"query"`
	Documents       []string `json:"documents"`
	TopN            *int     `json:"top_n,omitempty"`
	ReturnDocuments *bool    `json:"return_documents,omitempty"`
}

// RerankResponse is a rerank response body.
type RerankResponse struct {
	Results []RerankResult `json:"results"`
	Model   string         `json:"model"`
}

// RerankResult is one reranked document with its relevance score.
type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
	Document       *string `json:"document,omitempty"` // echoed back only when return_documents is set
}
|
||||
|
||||
// ErrorResponse is the OpenAI error envelope.
type ErrorResponse struct {
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the error message, type, offending parameter and code.
type ErrorDetail struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   any    `json:"param"` // no omitempty: null param must be emitted
	Code    string `json:"code"`
}
|
||||
Reference in New Issue
Block a user