1
0

refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包

引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间
无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化
ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
This commit is contained in:
2026-04-20 00:36:27 +08:00
parent 26810d9410
commit 1dac347d3b
65 changed files with 9690 additions and 2139 deletions

View File

@@ -0,0 +1,371 @@
package handler
import (
"bufio"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
)
// ProxyHandler 统一代理处理器
type ProxyHandler struct {
engine *conversion.ConversionEngine
client provider.ProviderClient
routingService service.RoutingService
providerService service.ProviderService
statsService service.StatsService
logger *zap.Logger
}
// NewProxyHandler 创建统一代理处理器
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
return &ProxyHandler{
engine: engine,
client: client,
routingService: routingService,
providerService: providerService,
statsService: statsService,
logger: zap.L(),
}
}
// HandleProxy 处理代理请求
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol")
if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
return
}
// 原始路径: /v1/{path}
path := c.Param("path")
if strings.HasPrefix(path, "/") {
path = path[1:]
}
nativePath := "/v1/" + path
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
return
}
// 解析 model 名称(从 JSON body 中提取GET 请求无 body
modelName := ""
if len(body) > 0 {
modelName = extractModelName(body)
}
// 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{
URL: nativePath,
Method: c.Request.Method,
Headers: extractHeaders(c),
Body: body,
}
// 路由
routeResult, err := h.routingService.Route(modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
return
}
h.writeError(c, err, clientProtocol)
return
}
// 确定 providerProtocol
providerProtocol := routeResult.Provider.Protocol
if providerProtocol == "" {
providerProtocol = "openai"
}
// 构建 TargetProvider
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName,
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
if isStream {
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
} else {
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
}
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
return
}
// 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
return
}
// 转换响应
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
return
}
// 设置响应头
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
}()
}
// handleStream 处理流式请求
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
// 创建流式转换器
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
// 发送流式请求
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
for event := range eventChan {
if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
break
}
if event.Done {
// flush 转换器
chunks := streamConverter.Flush()
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
}
break
}
chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
}
}
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
}()
}
// isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if ifaceType != conversion.InterfaceTypeChat {
return false
}
for i, b := range body {
if b == '"' && i+8 <= len(body) {
if string(body[i:i+8]) == `"stream"` {
for j := i + 8; j < len(body) && j < i+20; j++ {
if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" {
return true
}
}
}
}
}
return false
}
// writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
c.Data(statusCode, "application/json", body)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// writeError 写入路由错误
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
}
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
return
}
providers, err := h.providerService.List()
if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
return
}
p := providers[0]
providerProtocol := p.Protocol
if providerProtocol == "" {
providerProtocol = "openai"
}
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL
headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json"
}
outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL,
Method: inSpec.Method,
Headers: headers,
Body: inSpec.Body,
}
} else {
outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
}
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
}
// extractModelName 从 JSON body 中提取 model
func extractModelName(body []byte) string {
inQuote := false
escaped := false
keyStart := -1
keyEnd := -1
lookingForKey := true
lookingForValue := false
valueStart := -1
for i := 0; i < len(body); i++ {
b := body[i]
if escaped {
escaped = false
continue
}
if b == '\\' {
escaped = true
continue
}
if b == '"' {
if !inQuote {
inQuote = true
if lookingForKey {
keyStart = i + 1
}
if lookingForValue {
valueStart = i + 1
}
} else {
inQuote = false
if lookingForKey && keyStart >= 0 {
keyEnd = i
if string(body[keyStart:keyEnd]) == "model" {
lookingForKey = false
lookingForValue = true
}
} else if lookingForValue && valueStart >= 0 {
return string(body[valueStart:i])
}
}
}
if !inQuote && lookingForValue && b == ':' {
// 等待值开始
}
}
return ""
}
// extractHeaders 从 Gin context 提取请求头
func extractHeaders(c *gin.Context) map[string]string {
headers := make(map[string]string)
for k, vs := range c.Request.Header {
if len(vs) > 0 {
headers[k] = vs[0]
}
}
return headers
}