refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间 无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化 ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
This commit is contained in:
@@ -1,217 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
|
||||
errMsg := formatValidationErrors(validationErrors)
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: errMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "请求转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, openaiReq, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "响应转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "image" {
|
||||
return fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: appErr.Message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
@@ -34,8 +35,8 @@ func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
@@ -84,61 +85,14 @@ func (m *mockModelService) Update(id string, updates map[string]interface{}) err
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
resp *openai.ChatCompletionResponse
|
||||
eventChan chan provider.StreamEvent
|
||||
err error
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
return m.resp, m.err
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
|
||||
return m.eventChan, m.err
|
||||
}
|
||||
|
||||
// ============ OpenAI Handler 测试 ============
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
|
||||
// 缺少 model 字段
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
|
||||
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
|
||||
h := NewOpenAIHandler(nil, routingSvc, nil)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
@@ -283,8 +237,16 @@ func TestFormatValidationErrors(t *testing.T) {
|
||||
"model": "模型名称不能为空",
|
||||
"messages": "消息列表不能为空",
|
||||
}
|
||||
result := formatValidationErrors(errs)
|
||||
result := formatMapErrors(errs)
|
||||
require.Contains(t, result, "请求验证失败")
|
||||
require.Contains(t, result, "model")
|
||||
require.Contains(t, result, "messages")
|
||||
}
|
||||
|
||||
func formatMapErrors(errs map[string]string) string {
|
||||
parts := make([]string, 0, len(errs))
|
||||
for field, msg := range errs {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, &req, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors 将验证错误 map 格式化为字符串
|
||||
func formatValidationErrors(errors map[string]string) string {
|
||||
parts := make([]string, 0, len(errors))
|
||||
for field, msg := range errors {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
@@ -26,10 +26,11 @@ func NewProviderHandler(providerService service.ProviderService) *ProviderHandle
|
||||
// CreateProvider 创建供应商
|
||||
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -39,11 +40,17 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
protocol := req.Protocol
|
||||
if protocol == "" {
|
||||
protocol = "openai"
|
||||
}
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Protocol: protocol,
|
||||
}
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
|
||||
371
backend/internal/handler/proxy_handler.go
Normal file
371
backend/internal/handler/proxy_handler.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProxyHandler 统一代理处理器
|
||||
type ProxyHandler struct {
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
providerService service.ProviderService
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
engine: engine,
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
providerService: providerService,
|
||||
statsService: statsService,
|
||||
logger: zap.L(),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleProxy 处理代理请求
|
||||
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||||
clientProtocol := c.Param("protocol")
|
||||
if clientProtocol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
|
||||
return
|
||||
}
|
||||
|
||||
// 原始路径: /v1/{path}
|
||||
path := c.Param("path")
|
||||
if strings.HasPrefix(path, "/") {
|
||||
path = path[1:]
|
||||
}
|
||||
nativePath := "/v1/" + path
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 model 名称(从 JSON body 中提取,GET 请求无 body)
|
||||
modelName := ""
|
||||
if len(body) > 0 {
|
||||
modelName = extractModelName(body)
|
||||
}
|
||||
|
||||
// 构建输入 HTTPRequestSpec
|
||||
inSpec := conversion.HTTPRequestSpec{
|
||||
URL: nativePath,
|
||||
Method: c.Request.Method,
|
||||
Headers: extractHeaders(c),
|
||||
Body: body,
|
||||
}
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.Route(modelName)
|
||||
if err != nil {
|
||||
// GET 请求或无法提取 model 时,直接转发到上游
|
||||
if len(body) == 0 || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol)
|
||||
return
|
||||
}
|
||||
h.writeError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 确定 providerProtocol
|
||||
providerProtocol := routeResult.Provider.Protocol
|
||||
if providerProtocol == "" {
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
// 构建 TargetProvider
|
||||
targetProvider := conversion.NewTargetProvider(
|
||||
routeResult.Provider.BaseURL,
|
||||
routeResult.Provider.APIKey,
|
||||
routeResult.Model.ModelName,
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
if isStream {
|
||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
} else {
|
||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应
|
||||
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
}()
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
for i, b := range body {
|
||||
if b == '"' && i+8 <= len(body) {
|
||||
if string(body[i:i+8]) == `"stream"` {
|
||||
for j := i + 8; j < len(body) && j < i+20; j++ {
|
||||
if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
c.Data(statusCode, "application/json", body)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
// writeError 写入路由错误
|
||||
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
||||
registry := h.engine.GetRegistry()
|
||||
adapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil || len(providers) == 0 {
|
||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
||||
return
|
||||
}
|
||||
|
||||
p := providers[0]
|
||||
providerProtocol := p.Protocol
|
||||
if providerProtocol == "" {
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
||||
|
||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||
|
||||
var outSpec *conversion.HTTPRequestSpec
|
||||
if clientProtocol == providerProtocol {
|
||||
upstreamURL := p.BaseURL + inSpec.URL
|
||||
headers := adapter.BuildHeaders(targetProvider)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
outSpec = &conversion.HTTPRequestSpec{
|
||||
URL: upstreamURL,
|
||||
Method: inSpec.Method,
|
||||
Headers: headers,
|
||||
Body: inSpec.Body,
|
||||
}
|
||||
} else {
|
||||
outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
}
|
||||
|
||||
// extractModelName 从 JSON body 中提取 model
|
||||
func extractModelName(body []byte) string {
|
||||
inQuote := false
|
||||
escaped := false
|
||||
keyStart := -1
|
||||
keyEnd := -1
|
||||
lookingForKey := true
|
||||
lookingForValue := false
|
||||
valueStart := -1
|
||||
|
||||
for i := 0; i < len(body); i++ {
|
||||
b := body[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if b == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if b == '"' {
|
||||
if !inQuote {
|
||||
inQuote = true
|
||||
if lookingForKey {
|
||||
keyStart = i + 1
|
||||
}
|
||||
if lookingForValue {
|
||||
valueStart = i + 1
|
||||
}
|
||||
} else {
|
||||
inQuote = false
|
||||
if lookingForKey && keyStart >= 0 {
|
||||
keyEnd = i
|
||||
if string(body[keyStart:keyEnd]) == "model" {
|
||||
lookingForKey = false
|
||||
lookingForValue = true
|
||||
}
|
||||
} else if lookingForValue && valueStart >= 0 {
|
||||
return string(body[valueStart:i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if !inQuote && lookingForValue && b == ':' {
|
||||
// 等待值开始
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
func extractHeaders(c *gin.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for k, vs := range c.Request.Header {
|
||||
if len(vs) > 0 {
|
||||
headers[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
Reference in New Issue
Block a user