fix: 修正 conversion 代理路径和错误边界
This commit is contained in:
@@ -4,10 +4,10 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`)
|
||||
- 支持 OpenAI 协议(`/openai/...`,例如 `/openai/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||
- 同协议透传(零语义损失、零序列化开销)
|
||||
- 同协议透传(跳过 Canonical 全量转换,保持协议语义)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 支持 Thinking / Reasoning
|
||||
@@ -220,7 +220,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,**零序列化开销**:
|
||||
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
@@ -229,12 +229,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
@@ -301,6 +303,7 @@ StreamConverter (接口)
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
@@ -460,7 +463,7 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
### 代理接口
|
||||
|
||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
||||
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath,由对应 adapter 识别和组合上游 URL。
|
||||
|
||||
#### OpenAI 协议
|
||||
|
||||
@@ -478,10 +481,20 @@ POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough,跳过 Canonical 全量转换。
|
||||
|
||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||
|
||||
**base_url 约定**:
|
||||
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`。
|
||||
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`。
|
||||
|
||||
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
|
||||
|
||||
**流式透传边界**:同协议无响应 model 改写时 raw passthrough,保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON,仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
|
||||
|
||||
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`。
|
||||
|
||||
### 管理接口
|
||||
|
||||
#### 供应商管理
|
||||
|
||||
@@ -49,6 +49,28 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
// docs/api_reference/anthropic defines messages and models under /v1.
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/v1/messages", conversion.InterfaceTypeChat},
|
||||
{"/v1/models", conversion.InterfaceTypeModels},
|
||||
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
|
||||
{"/messages", conversion.InterfaceTypePassthrough},
|
||||
{"/models", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@ package conversion
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -71,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
|
||||
|
||||
// ConvertHttpRequest 转换 HTTP 请求
|
||||
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
||||
nativePath := spec.URL
|
||||
nativePath, rawQuery := splitRequestPath(spec.URL)
|
||||
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
@@ -96,8 +98,11 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
}
|
||||
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + nativePath,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: rewrittenBody,
|
||||
@@ -115,6 +120,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
@@ -122,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + providerURL,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
@@ -198,7 +204,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
InterfaceType: interfaceType,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
@@ -268,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
||||
return nil, NewRequestJSONParseError("解码请求失败", err)
|
||||
}
|
||||
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
@@ -276,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsUnsupportedMultimodal(canonicalReq) {
|
||||
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
|
||||
}
|
||||
|
||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||
if err != nil {
|
||||
@@ -287,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
return nil, NewResponseJSONParseError("解码响应失败", err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
@@ -375,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
|
||||
if err != nil {
|
||||
return InterfaceTypePassthrough, err
|
||||
}
|
||||
nativePath, _ = splitRequestPath(nativePath)
|
||||
return adapter.DetectInterfaceType(nativePath), nil
|
||||
}
|
||||
|
||||
@@ -398,3 +408,46 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
}
|
||||
|
||||
func splitRequestPath(rawPath string) (string, string) {
|
||||
path, query, found := strings.Cut(rawPath, "?")
|
||||
if !found {
|
||||
return rawPath, ""
|
||||
}
|
||||
return path, query
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
if strings.Contains(path, "?") {
|
||||
return path + "&" + rawQuery
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
if baseURL == "" {
|
||||
return path
|
||||
}
|
||||
if path == "" {
|
||||
return baseURL
|
||||
}
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "image", "audio", "video", "file":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
63
backend/internal/conversion/engine_adapter_test.go
Normal file
63
backend/internal/conversion/engine_adapter_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package conversion_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
adapter conversion.ProtocolAdapter
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
baseURL string
|
||||
nativePath string
|
||||
expectedURL string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "openai base url includes version path",
|
||||
adapter: openai.NewAdapter(),
|
||||
clientProtocol: "openai",
|
||||
providerProtocol: "openai",
|
||||
baseURL: "http://example.com/v1",
|
||||
nativePath: "/chat/completions",
|
||||
expectedURL: "http://example.com/v1/chat/completions",
|
||||
body: []byte(`{"model":"gpt-4","messages":[]}`),
|
||||
},
|
||||
{
|
||||
name: "anthropic native path keeps v1",
|
||||
adapter: anthropic.NewAdapter(),
|
||||
clientProtocol: "anthropic",
|
||||
providerProtocol: "anthropic",
|
||||
baseURL: "http://example.com",
|
||||
nativePath: "/v1/messages",
|
||||
expectedURL: "http://example.com/v1/messages",
|
||||
body: []byte(`{"model":"claude","messages":[]}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(tt.adapter))
|
||||
|
||||
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
|
||||
URL: tt.nativePath,
|
||||
Method: "POST",
|
||||
Body: tt.body,
|
||||
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedURL, out.URL)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
@@ -498,12 +499,13 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证 chunk 改写
|
||||
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
||||
// 验证 SSE frame 中的 data JSON 被改写
|
||||
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
|
||||
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,13 @@ const (
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrorDetailPhase = "phase"
|
||||
ErrorPhaseRequest = "request"
|
||||
ErrorPhaseResponse = "response"
|
||||
)
|
||||
|
||||
// ConversionError 协议转换错误
|
||||
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
|
||||
}
|
||||
}
|
||||
|
||||
// NewRequestJSONParseError 创建请求 JSON 解析错误。
|
||||
func NewRequestJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// NewResponseJSONParseError 创建响应 JSON 解析错误。
|
||||
func NewResponseJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// WithClientProtocol 设置客户端协议
|
||||
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
|
||||
e.ClientProtocol = protocol
|
||||
|
||||
@@ -44,6 +44,29 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
// docs/api_reference/openai, excluding responses, defines paths without /v1.
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"/models", conversion.InterfaceTypeModels},
|
||||
{"/models/gpt-4.1", conversion.InterfaceTypeModelInfo},
|
||||
{"/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"/rerank", conversion.InterfaceTypeRerank},
|
||||
{"/v1/chat/completions", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package conversion
|
||||
|
||||
import "nex/backend/internal/conversion/canonical"
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder 流式解码器接口
|
||||
type StreamDecoder interface {
|
||||
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
||||
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 逐 chunk 改写 model 字段
|
||||
// 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
@@ -55,24 +61,45 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 改写 chunk 中的 model 字段
|
||||
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
// 改写失败,返回原始 chunk
|
||||
return [][]byte{rawChunk}
|
||||
}
|
||||
c.buffer = append(c.buffer, rawChunk...)
|
||||
frames, rest := splitSSEFrames(c.buffer)
|
||||
c.buffer = rest
|
||||
|
||||
return [][]byte{rewrittenChunk}
|
||||
result := make([][]byte, 0, len(frames))
|
||||
for _, frame := range frames {
|
||||
result = append(result, c.rewriteFrame(frame))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush 无缓冲数据
|
||||
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
||||
return frame
|
||||
}
|
||||
|
||||
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
return frame
|
||||
}
|
||||
|
||||
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
||||
}
|
||||
|
||||
// Flush 输出未形成完整 frame 的剩余数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
if len(c.buffer) == 0 {
|
||||
return nil
|
||||
}
|
||||
frame := append([]byte(nil), c.buffer...)
|
||||
c.buffer = nil
|
||||
return [][]byte{c.rewriteFrame(frame)}
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
func splitSSEFrames(data []byte) ([][]byte, []byte) {
|
||||
var frames [][]byte
|
||||
for len(data) > 0 {
|
||||
idx, sepLen := findSSEFrameSeparator(data)
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
end := idx + sepLen
|
||||
frames = append(frames, append([]byte(nil), data[:end]...))
|
||||
data = data[end:]
|
||||
}
|
||||
return frames, data
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
|
||||
lineEnding, separator := sseLineEnding(frame)
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
out := make([]string, 0, len(lines)+1)
|
||||
dataWritten := false
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
if !dataWritten {
|
||||
for _, dataLine := range strings.Split(data, "\n") {
|
||||
out = append(out, "data: "+dataLine)
|
||||
}
|
||||
dataWritten = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out = append(out, line)
|
||||
}
|
||||
if !dataWritten {
|
||||
out = append(out, "data: "+data)
|
||||
}
|
||||
return []byte(strings.Join(out, lineEnding) + separator)
|
||||
}
|
||||
|
||||
func sseLineEnding(frame []byte) (string, string) {
|
||||
if bytes.Contains(frame, []byte("\r\n")) {
|
||||
return "\r\n", "\r\n\r\n"
|
||||
}
|
||||
return "\n", "\n\n"
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -48,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||||
clientProtocol := c.Param("protocol")
|
||||
if clientProtocol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
path = "/" + path
|
||||
}
|
||||
nativePath := path
|
||||
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
|
||||
|
||||
// 获取 client adapter
|
||||
registry := h.engine.GetRegistry()
|
||||
clientAdapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -80,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
|
||||
return
|
||||
}
|
||||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||
@@ -90,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析统一模型 ID(使用 adapter.ExtractModelName)
|
||||
var providerID, modelName string
|
||||
if len(body) > 0 {
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err == nil && unifiedID != "" {
|
||||
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err == nil {
|
||||
providerID = pid
|
||||
modelName = mn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建输入 HTTPRequestSpec
|
||||
inSpec := conversion.HTTPRequestSpec{
|
||||
URL: nativePath,
|
||||
URL: requestPath,
|
||||
Method: c.Request.Method,
|
||||
Headers: extractHeaders(c),
|
||||
Body: body,
|
||||
}
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
|
||||
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err != nil {
|
||||
if isInvalidJSONError(err) {
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
|
||||
return
|
||||
}
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
if providerID == "" || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
// GET 请求或无法提取 model 时,直接转发到上游
|
||||
if len(body) == 0 || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol)
|
||||
return
|
||||
}
|
||||
h.writeError(c, err, clientProtocol)
|
||||
h.writeRouteError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,9 +155,6 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 计算统一模型 ID(用于响应覆写)
|
||||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||
|
||||
@@ -156,6 +165,28 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isInvalidJSONError(err error) bool {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
@@ -170,7 +201,11 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -182,15 +217,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
@@ -206,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
// 发送流式请求
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
@@ -225,8 +260,9 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
@@ -237,6 +273,7 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -246,6 +283,12 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
@@ -291,7 +334,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
models, err := h.providerService.ListEnabledModels()
|
||||
if err != nil {
|
||||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -313,7 +356,7 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
body, err := adapter.EncodeModelsResponse(modelList)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -325,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
||||
// 解析统一模型 ID
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的统一模型 ID 格式",
|
||||
"code": "INVALID_MODEL_ID",
|
||||
})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库查询模型
|
||||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -351,46 +391,103 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
||||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
// writeConversionError 写入网关层转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
|
||||
if encodeErr != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
|
||||
return
|
||||
}
|
||||
c.Data(statusCode, "application/json", body)
|
||||
statusCode, code, message := mapConversionError(convErr)
|
||||
h.writeProxyError(c, statusCode, code, message)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||||
}
|
||||
|
||||
// writeError 写入路由错误
|
||||
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
|
||||
switch err.Code {
|
||||
case conversion.ErrorCodeJSONParseError:
|
||||
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
|
||||
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
|
||||
}
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
case conversion.ErrorCodeInvalidInput,
|
||||
conversion.ErrorCodeMissingRequiredField,
|
||||
conversion.ErrorCodeProtocolConstraint:
|
||||
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
|
||||
case conversion.ErrorCodeInterfaceNotSupported:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
|
||||
case conversion.ErrorCodeUnsupportedMultimodal:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
|
||||
default:
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
switch appErr.Code {
|
||||
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
|
||||
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
|
||||
default:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
|
||||
h.logger.Error("上游不可达", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": message,
|
||||
"code": code,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range resp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
|
||||
registry := h.engine.GetRegistry()
|
||||
adapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil || len(providers) == 0 {
|
||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
||||
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
|
||||
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -400,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
||||
|
||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||
|
||||
var outSpec *conversion.HTTPRequestSpec
|
||||
if clientProtocol == providerProtocol {
|
||||
upstreamURL := p.BaseURL + inSpec.URL
|
||||
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
|
||||
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
|
||||
headers := adapter.BuildHeaders(targetProvider)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
outSpec = &conversion.HTTPRequestSpec{
|
||||
URL: upstreamURL,
|
||||
URL: joinBaseURL(p.BaseURL, upstreamPath),
|
||||
Method: inSpec.Method,
|
||||
Headers: headers,
|
||||
Body: inSpec.Body,
|
||||
@@ -425,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -437,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
|
||||
if err != nil {
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("透传流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripRawQuery(path string) string {
|
||||
pathOnly, _, _ := strings.Cut(path, "?")
|
||||
return pathOnly
|
||||
}
|
||||
|
||||
func rawQueryFromPath(path string) string {
|
||||
_, rawQuery, found := strings.Cut(path, "?")
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
return rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func headerValue(headers map[string]string, key string) string {
|
||||
for k, v := range headers {
|
||||
if strings.EqualFold(k, key) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func filterHopByHopHeaders(headers map[string]string) map[string]string {
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
hopByHop := map[string]struct{}{
|
||||
"connection": {},
|
||||
"transfer-encoding": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
filtered := make(map[string]string, len(headers))
|
||||
for k, v := range headers {
|
||||
if _, skip := hopByHop[strings.ToLower(k)]; skip {
|
||||
continue
|
||||
}
|
||||
filtered[k] = v
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
@@ -73,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -93,7 +94,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -109,9 +110,8 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound)
|
||||
routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return(nil, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
@@ -119,10 +119,11 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
@@ -131,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -146,10 +147,11 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
@@ -158,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -173,10 +175,11 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
@@ -185,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -199,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -209,12 +212,13 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
||||
assert.Contains(t, w.Body.String(), "Hello")
|
||||
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
@@ -223,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
return nil, context.DeadlineExceeded
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
@@ -238,10 +242,11 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
||||
@@ -286,7 +291,7 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
|
||||
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
||||
@@ -329,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -348,7 +353,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -371,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
|
||||
|
||||
h.writeConversionError(c, context.DeadlineExceeded, "openai")
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
||||
@@ -390,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
||||
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
|
||||
h.writeConversionError(c, convErr, "openai")
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
t.Run("request json parse error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("response json parse error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
||||
@@ -423,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
|
||||
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -445,7 +484,7 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -460,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -473,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -483,7 +522,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -505,7 +544,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -517,7 +556,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -532,7 +571,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -544,7 +583,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -560,7 +599,7 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
|
||||
}, nil)
|
||||
@@ -579,7 +618,7 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -591,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -611,7 +650,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -720,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeError(c, fmt.Errorf("model not found"), "openai")
|
||||
h.writeRouteError(c, fmt.Errorf("model not found"))
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
||||
@@ -994,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -1012,7 +1052,7 @@ data: {"type":"message_stop"}
|
||||
`)}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -1100,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Contains(t, resp, "error")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
path string
|
||||
requestPath string
|
||||
baseURL string
|
||||
expectedURL string
|
||||
body string
|
||||
responseBody string
|
||||
responseModel string
|
||||
}{
|
||||
{
|
||||
name: "openai path has no v1 after gateway prefix",
|
||||
protocol: "openai",
|
||||
path: "/chat/completions",
|
||||
requestPath: "/openai/chat/completions",
|
||||
baseURL: "https://api.test.com/v1",
|
||||
expectedURL: "https://api.test.com/v1/chat/completions",
|
||||
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
|
||||
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
|
||||
responseModel: "p1/gpt-4",
|
||||
},
|
||||
{
|
||||
name: "anthropic path keeps v1 after gateway prefix",
|
||||
protocol: "anthropic",
|
||||
path: "/v1/messages",
|
||||
requestPath: "/anthropic/v1/messages",
|
||||
baseURL: "https://api.anthropic.test",
|
||||
expectedURL: "https://api.anthropic.test/v1/messages",
|
||||
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
|
||||
responseModel: "p1/gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Equal(t, tt.expectedURL, spec.URL)
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(tt.responseBody),
|
||||
}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
|
||||
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.responseModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"X-Upstream-Error": "rate-limit",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
Body: []byte(`{"error":{"message":"rate limited"}}`),
|
||||
}, nil)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
|
||||
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
|
||||
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
|
||||
}
|
||||
|
||||
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
|
||||
Body: []byte(`{"error":"upstream down"}`),
|
||||
}, nil)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("Connection"))
|
||||
}
|
||||
|
||||
func TestFilterHopByHopHeaders(t *testing.T) {
|
||||
filtered := filterHopByHopHeaders(map[string]string{
|
||||
"Connection": "close",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Keep-Alive": "timeout=5",
|
||||
"Proxy-Authenticate": "Basic",
|
||||
"Proxy-Authorization": "Basic token",
|
||||
"TE": "trailers",
|
||||
"Trailer": "Expires",
|
||||
"Upgrade": "websocket",
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-ID": "req-1",
|
||||
})
|
||||
|
||||
assert.Equal(t, map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-ID": "req-1",
|
||||
}, filtered)
|
||||
}
|
||||
|
||||
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
|
||||
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(`{"ok":true}`),
|
||||
}, nil
|
||||
})
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
|
||||
}
|
||||
|
||||
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Contains(t, string(spec.Body), "image_url")
|
||||
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
|
||||
}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(body))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||
}
|
||||
|
||||
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||
ch := make(chan provider.StreamEvent, 3)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -43,6 +44,14 @@ type StreamEvent struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
// StreamResponse 表示上游流式 HTTP 响应。
|
||||
type StreamResponse struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
Events <-chan StreamEvent
|
||||
}
|
||||
|
||||
// Client 协议无关的供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
@@ -55,7 +64,7 @@ type Client struct {
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
@@ -116,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
|
||||
}
|
||||
|
||||
// SendStream 发送流式请求
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
|
||||
var bodyReader io.Reader
|
||||
if len(spec.Body) > 0 {
|
||||
bodyReader = bytes.NewReader(spec.Body)
|
||||
@@ -139,23 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
||||
return nil, pkgErrors.ErrRequestSend.WithCause(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respHeaders := extractResponseHeaders(resp.Header)
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d,读取错误响应失败: %w", resp.StatusCode, readErr)
|
||||
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
|
||||
}
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Body: errBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Events: eventChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流
|
||||
@@ -208,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
for {
|
||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
||||
idx, sepLen := findSSEFrameSeparator(dataBuf)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
rawEvent := dataBuf[:idx]
|
||||
dataBuf = dataBuf[idx+2:]
|
||||
frameEnd := idx + sepLen
|
||||
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
|
||||
dataBuf = dataBuf[frameEnd:]
|
||||
|
||||
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
|
||||
if isSSEDoneFrame(rawEvent) {
|
||||
eventChan <- StreamEvent{Data: rawEvent}
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
@@ -225,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if len(dataBuf) > 0 {
|
||||
eventChan <- StreamEvent{Data: dataBuf}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isSSEDoneFrame(frame []byte) bool {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func extractResponseHeaders(header http.Header) map[string]string {
|
||||
respHeaders := make(map[string]string)
|
||||
for k, vs := range header {
|
||||
if len(vs) > 0 {
|
||||
respHeaders[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return respHeaders
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
// isNetworkError 判断是否为网络相关错误
|
||||
func isNetworkError(err error) bool {
|
||||
if err == nil {
|
||||
|
||||
@@ -110,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eventChan)
|
||||
require.NotNil(t, streamResp)
|
||||
require.Equal(t, http.StatusOK, streamResp.StatusCode)
|
||||
require.NotNil(t, streamResp.Events)
|
||||
|
||||
for range eventChan {
|
||||
for range streamResp.Events {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
_, err := client.SendStream(context.Background(), spec)
|
||||
assert.Error(t, err)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
|
||||
}
|
||||
|
||||
func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
@@ -164,12 +168,13 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
@@ -180,9 +185,56 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
|
||||
assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
|
||||
assert.Contains(t, string(dataEvents[0]), "Hello")
|
||||
assert.Contains(t, string(dataEvents[1]), "World")
|
||||
assert.Contains(t, string(dataEvents[2]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
assert.Contains(t, string(dataEvents[0]), "\n\n")
|
||||
}
|
||||
|
||||
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, dataEvents, 2)
|
||||
assert.Contains(t, string(dataEvents[0]), "plain text")
|
||||
assert.Contains(t, string(dataEvents[1]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -203,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(ctx, spec)
|
||||
streamResp, err := client.SendStream(ctx, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
|
||||
var gotError bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
gotError = true
|
||||
}
|
||||
@@ -264,12 +316,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataCount int
|
||||
var doneCount int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneCount++
|
||||
@@ -279,7 +331,7 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
dataCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
|
||||
assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
|
||||
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
|
||||
}
|
||||
|
||||
@@ -308,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataEvents int
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Done {
|
||||
doneEvents++
|
||||
} else {
|
||||
dataEvents++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
|
||||
assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -397,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotData bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
} else if !event.Done {
|
||||
gotData = true
|
||||
|
||||
@@ -521,15 +521,14 @@ func TestConversion_OldRoutes_Return404(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader(`{"model":"test"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
// Gin 路由匹配但协议不支持返回 400
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
|
||||
// 旧 Anthropic 路由
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/v1/messages", strings.NewReader(`{"model":"test"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
// ============ Provider Protocol 字段测试 ============
|
||||
|
||||
@@ -59,10 +59,10 @@ func (mr *MockProviderClientMockRecorder) Send(ctx, spec any) *gomock.Call {
|
||||
}
|
||||
|
||||
// SendStream mocks base method.
|
||||
func (m *MockProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
func (m *MockProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SendStream", ctx, spec)
|
||||
ret0, _ := ret[0].(<-chan provider.StreamEvent)
|
||||
ret0, _ := ret[0].(*provider.StreamResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user