1
0
Files
nex/backend/internal/handler/proxy_handler.go
lanyuanxiaoyao 1dac347d3b refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间
无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化
ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
2026-04-20 00:36:27 +08:00

372 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"bufio"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"nex/backend/internal/conversion"
	"nex/backend/internal/domain"
	"nex/backend/internal/provider"
	"nex/backend/internal/service"
)
// ProxyHandler is the unified proxy handler: a single gin handler that accepts
// requests in any supported client protocol, routes them to a provider, and
// converts requests/responses between the client and provider protocols.
type ProxyHandler struct {
	// engine performs protocol conversion (requests, responses, stream chunks,
	// error bodies) and interface-type detection.
	engine *conversion.ConversionEngine
	// client sends converted requests upstream (Send for unary, SendStream for SSE).
	client provider.ProviderClient
	// routingService resolves a model name to a provider/model pair.
	routingService service.RoutingService
	// providerService lists configured providers; used by the passthrough path
	// for requests that carry no model name.
	providerService service.ProviderService
	// statsService records usage (provider ID, model name) after routed requests.
	statsService service.StatsService
	// logger is captured from zap's global logger at construction time.
	logger *zap.Logger
}
// NewProxyHandler wires a ProxyHandler from its collaborators. The handler's
// logger is whatever zap's global logger (zap.L()) is at the time of the call.
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
	h := new(ProxyHandler)
	h.engine = engine
	h.client = client
	h.routingService = routingService
	h.providerService = providerService
	h.statsService = statsService
	h.logger = zap.L()
	return h
}
// HandleProxy is the entry point for proxied API calls routed as
// /{protocol}/v1/{path...}.
//
// It reads the whole request body, extracts the model name from it (if any),
// and routes on that model:
//   - routing succeeds: the request is converted to the provider's protocol
//     (defaulting to "openai" when the provider has none) and handled as a
//     streaming or non-streaming call;
//   - routing fails and no model name was found (e.g. GET requests): the
//     request is forwarded via forwardPassthrough;
//   - routing fails for a named model: the routing error is written back.
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
	clientProtocol := c.Param("protocol")
	if clientProtocol == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
		return
	}

	// Rebuild the protocol-native path (/v1/...) without the protocol prefix.
	nativePath := "/v1/" + strings.TrimPrefix(c.Param("path"), "/")

	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
		return
	}

	// Bodiless requests (e.g. GET) carry no model name.
	var modelName string
	if len(body) != 0 {
		modelName = extractModelName(body)
	}

	inSpec := conversion.HTTPRequestSpec{
		URL:     nativePath,
		Method:  c.Request.Method,
		Headers: extractHeaders(c),
		Body:    body,
	}

	routeResult, err := h.routingService.Route(modelName)
	if err != nil {
		// An empty body always yields an empty model name, so this single check
		// covers both bodiless requests and bodies without a usable "model".
		if modelName == "" {
			h.forwardPassthrough(c, inSpec, clientProtocol)
		} else {
			h.writeError(c, err, clientProtocol)
		}
		return
	}

	providerProtocol := routeResult.Provider.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	targetProvider := conversion.NewTargetProvider(
		routeResult.Provider.BaseURL,
		routeResult.Provider.APIKey,
		routeResult.Model.ModelName,
	)

	// Both handlers share a signature; pick one and invoke it once.
	handle := h.handleNonStream
	if h.isStreamRequest(body, clientProtocol, nativePath) {
		handle = h.handleStream
	}
	handle(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
}
// handleNonStream serves a routed, non-streaming request. It converts the
// client request into the provider's protocol, sends it upstream, converts the
// upstream response back into the client's protocol, and writes it out. Usage
// is recorded asynchronously after the response has been written.
//
// Any failure along the way is logged and reported through writeConversionError.
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// The detection error is deliberately ignored; the zero interface type is
	// passed on to the response converter in that case.
	interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
	if err != nil {
		h.logger.Error("转换响应失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	for k, v := range convertedResp.Headers {
		switch http.CanonicalHeaderKey(k) {
		case "Content-Length", "Transfer-Encoding", "Connection":
			// The body may have been re-encoded by the conversion engine, so the
			// upstream framing headers may no longer describe it (a stale
			// Content-Length truncates or stalls the response). net/http
			// computes correct framing itself.
			continue
		}
		c.Header(k, v)
	}
	// Default the content type only when upstream did not provide one. This must
	// inspect the *response* headers; c.GetHeader would read the request's.
	if c.Writer.Header().Get("Content-Type") == "" {
		c.Header("Content-Type", "application/json")
	}
	c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)

	go func() {
		// Best-effort usage accounting; the result is intentionally discarded.
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
	}()
}
// handleStream serves a routed streaming (SSE) request. The client request is
// converted into the provider's protocol and sent with SendStream. Each upstream
// event is fed through a stream converter, and the resulting chunks are pushed
// to the client as soon as they are produced.
//
// The event loop ends in one of four ways:
//   - an upstream error;
//   - the upstream Done marker;
//   - the event channel being closed;
//   - a failed write to the client.
//
// Unless an error ended it, the converter is then flushed so any buffered tail
// reaches the client. Usage is recorded asynchronously at the end.
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
	if err != nil {
		h.logger.Error("创建流式转换器失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)
	// flushToClient drains writer into gin's response writer and then flushes
	// the underlying http.ResponseWriter. Without that second flush, net/http
	// keeps small writes in its own buffer and SSE events reach the client in
	// delayed batches. It returns false once a write has failed (typically the
	// client disconnected).
	flushToClient := func() bool {
		if err := writer.Flush(); err != nil {
			h.logger.Warn("写入流失败", zap.Error(err))
			return false
		}
		c.Writer.Flush()
		return true
	}

	healthy := true
	for event := range eventChan {
		if event.Error != nil {
			h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
			healthy = false
			break
		}
		if event.Done {
			break
		}
		for _, chunk := range streamConverter.ProcessChunk(event.Data) {
			// bufio.Writer errors are sticky and surface in flushToClient.
			_, _ = writer.Write(chunk)
		}
		if !flushToClient() {
			healthy = false
			break
		}
	}

	// Emit whatever the converter still buffers. This also covers an upstream
	// that closes the channel without sending a Done event, so the converter's
	// tail is not dropped in that case.
	if healthy {
		for _, chunk := range streamConverter.Flush() {
			_, _ = writer.Write(chunk)
		}
		flushToClient()
	}

	go func() {
		// Best-effort usage accounting; the result is intentionally discarded.
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
	}()
}
// isStreamRequest reports whether a request should be served as a stream. Only
// chat-type interfaces (as detected for nativePath under clientProtocol) are
// considered, and streaming is requested by a top-level boolean "stream": true
// in the JSON body.
//
// Decoding the body as JSON replaces a byte scan that misfired on payloads such
// as {"stream":false,"a":true} or {"mode":"stream","x":true}. Non-JSON bodies,
// a missing key, or a non-boolean "stream" value all yield false.
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
	// A detection error is treated like a non-chat interface: not a stream.
	ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
	if ifaceType != conversion.InterfaceTypeChat {
		return false
	}

	var probe struct {
		Stream bool `json:"stream"`
	}
	if err := json.Unmarshal(body, &probe); err != nil {
		return false
	}
	return probe.Stream
}
// writeConversionError writes err to the client.
//
// If err is (or wraps) a *conversion.ConversionError, the engine encodes it in
// the client protocol's native error format, using the status code the engine
// chooses. Otherwise, or when that encoding itself fails, a generic
// HTTP 500 {"error": "..."} JSON body is written.
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
	var convErr *conversion.ConversionError
	if errors.As(err, &convErr) {
		body, statusCode, encErr := h.engine.EncodeError(convErr, clientProtocol)
		if encErr == nil {
			c.Data(statusCode, "application/json", body)
			return
		}
		// Previously the encode failure was ignored and a zero status with an
		// empty body was written, which gin sends as a 200 "success".
		h.logger.Error("编码错误响应失败", zap.Error(encErr))
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// writeError reports a routing failure as HTTP 404 with a plain
// {"error": "<message>"} JSON body. clientProtocol is accepted but not used, so
// the body is not shaped to the client protocol's own error format.
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
	payload := gin.H{"error": err.Error()}
	c.JSON(http.StatusNotFound, payload)
}
// forwardPassthrough forwards a request that carries no model name (e.g. GET
// requests) to an upstream provider without model-based routing.
//
// The first provider returned by providerService.List is used. When its
// protocol (default "openai") matches clientProtocol, the method, path, query
// string and body are forwarded verbatim, with the auth headers built by the
// client protocol's adapter. Otherwise the request and response are converted
// by the engine. Errors are written in the client's protocol where possible.
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
	registry := h.engine.GetRegistry()
	adapter, err := registry.Get(clientProtocol)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
		return
	}

	providers, err := h.providerService.List()
	if err != nil || len(providers) == 0 {
		// zap.Error(nil) is a no-op field, so this is safe when the list is just empty.
		h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL), zap.Error(err))
		c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
		return
	}

	// NOTE(review): the first listed provider receives all passthrough traffic;
	// List's ordering is not visible here, so that choice may be arbitrary.
	p := providers[0]
	providerProtocol := p.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	ifaceType := adapter.DetectInterfaceType(inSpec.URL)
	targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")

	var outSpec *conversion.HTTPRequestSpec
	if clientProtocol == providerProtocol {
		upstreamURL := p.BaseURL + inSpec.URL
		// inSpec.URL holds only the route path; re-attach the client's query
		// string so parameters (e.g. pagination on list endpoints) are not lost.
		if rawQuery := c.Request.URL.RawQuery; rawQuery != "" {
			upstreamURL += "?" + rawQuery
		}
		headers := adapter.BuildHeaders(targetProvider)
		if _, ok := headers["Content-Type"]; !ok {
			headers["Content-Type"] = "application/json"
		}
		outSpec = &conversion.HTTPRequestSpec{
			URL:     upstreamURL,
			Method:  inSpec.Method,
			Headers: headers,
			Body:    inSpec.Body,
		}
	} else {
		outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
		if err != nil {
			h.logger.Error("转换请求失败", zap.Error(err))
			h.writeConversionError(c, err, clientProtocol)
			return
		}
	}

	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
	if err != nil {
		h.logger.Error("转换响应失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	for k, v := range convertedResp.Headers {
		switch http.CanonicalHeaderKey(k) {
		case "Content-Length", "Transfer-Encoding", "Connection":
			// The body may have been re-encoded, so the upstream framing headers
			// may be stale; net/http computes correct framing itself.
			continue
		}
		c.Header(k, v)
	}
	// Default the content type only when upstream did not provide one. This must
	// inspect the *response* headers; c.GetHeader reads the request's, which for
	// GET requests is usually empty and would clobber upstream's Content-Type.
	if c.Writer.Header().Get("Content-Type") == "" {
		c.Header("Content-Type", "application/json")
	}
	c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
}
// extractModelName returns the value of the top-level "model" field of a JSON
// request body.
//
// It returns "" when:
//   - the body is not a JSON object;
//   - the key is absent or null;
//   - the value is not a string.
//
// Nested "model" keys and string values that merely equal "model" are ignored,
// and JSON escape sequences in the value are decoded.
//
// Note: encoding/json matches struct field names case-insensitively, so a key
// such as "Model" is also accepted.
func extractModelName(body []byte) string {
	var probe struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &probe); err != nil {
		return ""
	}
	return probe.Model
}
// extractHeaders flattens the incoming request's headers into a single-valued
// map. Keys are copied as they appear in c.Request.Header. For a header with
// several values only the first is kept, and headers with no values are skipped.
func extractHeaders(c *gin.Context) map[string]string {
	flat := make(map[string]string, len(c.Request.Header))
	for name, values := range c.Request.Header {
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}