1
0
Files
nex/backend/internal/handler/proxy_handler.go
lanyuanxiaoyao 395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00

435 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"bufio"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
	"nex/backend/internal/domain"
	"nex/backend/internal/provider"
	"nex/backend/internal/service"
	"nex/backend/pkg/modelid"
)
// ProxyHandler is the unified proxy handler. It bundles the collaborators
// needed to accept a client request in one protocol, route it to a provider
// and relay the (possibly converted) response.
type ProxyHandler struct {
	// Protocol conversion and upstream transport.
	engine *conversion.ConversionEngine // adapter registry plus request/response/stream conversion
	client provider.ProviderClient      // sends plain and streaming requests upstream

	// Business services.
	routingService  service.RoutingService  // resolves (provider ID, model name) to a route
	providerService service.ProviderService // provider and model lookups
	statsService    service.StatsService    // per-request usage recording

	// logger is taken from zap's global logger when the handler is built.
	logger *zap.Logger
}
// NewProxyHandler builds a ProxyHandler from its collaborators. The logger is
// not injected; it is taken from zap's global logger at construction time.
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
	h := new(ProxyHandler)
	h.logger = zap.L()
	h.engine = engine
	h.client = client
	h.routingService = routingService
	h.providerService = providerService
	h.statsService = statsService
	return h
}
// HandleProxy is the entry point for proxied requests of the form
// /{protocol}/v1/{path}.
//
// Flow:
//  1. Resolve the client protocol adapter from the URL prefix.
//  2. Answer Models / ModelInfo requests locally.
//  3. Otherwise read the body, extract the unified model ID
//     (provider_id/model_name) and route on it.
//  4. If routing fails and there is no body or no model name, pass the
//     request straight through to an upstream; otherwise report the routing
//     error.
//  5. Dispatch to the streaming or non-streaming pipeline.
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
	// The client protocol comes from the URL: /{protocol}/v1/...
	protocol := c.Param("protocol")
	if protocol == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
		return
	}

	// Rebuild the protocol-native path: /v1/{path}.
	nativePath := "/v1/" + strings.TrimPrefix(c.Param("path"), "/")

	// Look up the adapter for the client protocol.
	adapter, err := h.engine.GetRegistry().Get(protocol)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + protocol})
		return
	}

	// Model listing and model info are served locally, never forwarded.
	ifaceType := adapter.DetectInterfaceType(nativePath)
	switch ifaceType {
	case conversion.InterfaceTypeModels:
		h.handleModelsList(c, adapter)
		return
	case conversion.InterfaceTypeModelInfo:
		unifiedID, extractErr := adapter.ExtractUnifiedModelID(nativePath)
		if extractErr != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
			return
		}
		h.handleModelInfo(c, unifiedID, adapter)
		return
	}

	// Read the whole request body.
	payload, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
		return
	}

	// Extract the unified model ID via the adapter. Extraction or parse
	// failures are tolerated here: both values stay empty and the routing
	// step below decides what to do.
	var providerID, modelName string
	if len(payload) > 0 {
		if rawID, extractErr := adapter.ExtractModelName(payload, ifaceType); extractErr == nil && rawID != "" {
			if pid, name, parseErr := modelid.ParseUnifiedModelID(rawID); parseErr == nil {
				providerID, modelName = pid, name
			}
		}
	}

	// Describe the inbound request for the conversion engine.
	inSpec := conversion.HTTPRequestSpec{
		URL:     nativePath,
		Method:  c.Request.Method,
		Headers: extractHeaders(c),
		Body:    payload,
	}

	// Routing.
	route, err := h.routingService.RouteByModelName(providerID, modelName)
	if err != nil {
		// Bodiless requests (e.g. GET) or requests without a model name
		// are forwarded upstream as-is; anything else is a routing error.
		passthrough := len(payload) == 0 || modelName == ""
		if !passthrough {
			h.writeError(c, err, protocol)
			return
		}
		h.forwardPassthrough(c, inSpec, protocol)
		return
	}

	// Providers without an explicit protocol are treated as OpenAI.
	providerProtocol := route.Provider.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	// Build the target provider. ModelName is what the request body gets
	// rewritten to in the Smart Passthrough scenario:
	//   - same protocol: the unified ID in the body becomes ModelName
	//     (the upstream name);
	//   - cross protocol: ModelName is encoded into the body during full
	//     conversion.
	target := conversion.NewTargetProvider(
		route.Provider.BaseURL,
		route.Provider.APIKey,
		route.Model.ModelName, // upstream model name used for request rewriting
	)

	streaming := h.isStreamRequest(payload, protocol, nativePath)

	// The unified model ID is used to overwrite the model in the response.
	unifiedModelID := route.Model.UnifiedModelID()

	if !streaming {
		h.handleNonStream(c, inSpec, protocol, providerProtocol, target, route, unifiedModelID, ifaceType)
		return
	}
	h.handleStream(c, inSpec, protocol, providerProtocol, target, route, unifiedModelID, ifaceType)
}
// handleNonStream runs the non-streaming pipeline: convert the client request
// to the provider protocol, send it upstream, convert the response back to the
// client protocol and write it out.
//
// unifiedModelID is passed to the response conversion as the model override
// (used to overwrite the model field in cross-protocol scenarios). Conversion
// or transport failures are logged and reported through writeConversionError.
// After the response is written, usage is recorded asynchronously on a
// best-effort basis.
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Send it upstream.
	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Convert the response, passing the model override.
	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
	if err != nil {
		h.logger.Error("转换响应失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Copy response headers.
	for k, v := range convertedResp.Headers {
		c.Header(k, v)
	}
	// Default the Content-Type only when the converted response did not set
	// one. This must inspect the response headers: c.GetHeader reads the
	// incoming request's headers instead.
	if c.Writer.Header().Get("Content-Type") == "" {
		c.Header("Content-Type", "application/json")
	}
	c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)

	// Record usage off the request path. A failure must not affect the
	// response, but it is logged rather than silently dropped.
	go func() {
		if err := h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName); err != nil {
			h.logger.Warn("记录统计失败", zap.String("error", err.Error()))
		}
	}()
}
// handleStream runs the streaming pipeline: convert the client request, open
// an upstream event stream, and relay each converted chunk to the client as
// Server-Sent Events.
//
// unifiedModelID is handed to the stream converter as the model override
// (used to overwrite the model field in cross-protocol scenarios). Failures
// before the stream starts are logged and reported through
// writeConversionError. Once streaming has begun, an upstream error or a
// failed client write ends the relay. Usage is recorded asynchronously on a
// best-effort basis after the loop.
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Create the stream converter, passing the model override.
	streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
	if err != nil {
		h.logger.Error("创建流式转换器失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Open the upstream stream.
	eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.String("error", err.Error()))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)

	// emit writes one chunk and pushes it all the way to the client.
	// Flushing the bufio.Writer only moves bytes into the ResponseWriter;
	// the ResponseWriter itself must be flushed too, otherwise events can
	// sit in the HTTP server's buffer instead of reaching the client
	// promptly.
	emit := func(chunk []byte) error {
		if _, err := writer.Write(chunk); err != nil {
			return err
		}
		if err := writer.Flush(); err != nil {
			return err
		}
		c.Writer.Flush()
		return nil
	}

relay:
	for event := range eventChan {
		if event.Error != nil {
			h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
			break
		}

		if event.Done {
			// Drain whatever the converter is still holding.
			for _, chunk := range streamConverter.Flush() {
				if err := emit(chunk); err != nil {
					h.logger.Warn("写入流式响应失败", zap.String("error", err.Error()))
					break
				}
			}
			break
		}

		for _, chunk := range streamConverter.ProcessChunk(event.Data) {
			if err := emit(chunk); err != nil {
				// The client is most likely gone; stop relaying.
				// NOTE(review): assumes the producer behind eventChan
				// stops once the request context is cancelled — confirm.
				h.logger.Warn("写入流式响应失败", zap.String("error", err.Error()))
				break relay
			}
		}
	}

	// Record usage off the request path. A failure must not affect the
	// response, but it is logged rather than silently dropped.
	go func() {
		if err := h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName); err != nil {
			h.logger.Warn("记录统计失败", zap.String("error", err.Error()))
		}
	}()
}
// isStreamRequest reports whether the request asks for a streamed response.
// Only Chat-type interfaces can stream; for those, the JSON body's top-level
// "stream" field decides. A body that fails to parse counts as non-streaming.
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
	// The detection error is deliberately ignored; only a result equal to
	// the Chat interface type lets the check continue.
	if kind, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol); kind != conversion.InterfaceTypeChat {
		return false
	}

	var flags struct {
		Stream bool `json:"stream"`
	}
	return json.Unmarshal(body, &flags) == nil && flags.Stream
}
// handleModelsList serves GET /v1/models by aggregating locally: it lists all
// enabled models, maps them to canonical entries keyed by unified model ID,
// and lets the client protocol adapter encode the response body.
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
	// Fetch every enabled model.
	enabled, err := h.providerService.ListEnabledModels()
	if err != nil {
		h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
		return
	}

	// Map each model onto its canonical representation.
	entries := make([]canonical.CanonicalModel, len(enabled))
	for i, m := range enabled {
		entries[i] = canonical.CanonicalModel{
			ID:      m.UnifiedModelID(),
			Name:    m.ModelName,
			Created: m.CreatedAt.Unix(),
			OwnedBy: m.ProviderID,
		}
	}

	// Encode in the client's protocol format.
	payload, err := adapter.EncodeModelsResponse(&canonical.CanonicalModelList{Models: entries})
	if err != nil {
		h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
		return
	}

	c.Data(http.StatusOK, "application/json", payload)
}
// handleModelInfo serves GET /v1/models/{unified_id} from local data. It
// responds 400 when the unified ID cannot be parsed, 404 when the model lookup
// fails, and 500 when the adapter cannot encode the response.
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
	// Split the unified model ID into provider ID and model name.
	providerID, modelName, parseErr := modelid.ParseUnifiedModelID(unifiedID)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": "无效的统一模型 ID 格式",
			"code":  "INVALID_MODEL_ID",
		})
		return
	}

	// Look the model up; any lookup error is reported as not found.
	found, lookupErr := h.providerService.GetModelByProviderAndName(providerID, modelName)
	if lookupErr != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
		return
	}

	// Encode the canonical model info in the client's protocol format.
	encoded, encodeErr := adapter.EncodeModelInfoResponse(&canonical.CanonicalModelInfo{
		ID:      found.UnifiedModelID(),
		Name:    found.ModelName,
		Created: found.CreatedAt.Unix(),
		OwnedBy: found.ProviderID,
	})
	if encodeErr != nil {
		h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", encodeErr.Error()))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
		return
	}

	c.Data(http.StatusOK, "application/json", encoded)
}
// writeConversionError writes an error produced by the conversion/forwarding
// pipeline to the client.
//
// If err is, or wraps, a *conversion.ConversionError, the engine encodes it in
// the client protocol's error format and the engine-provided status code is
// used. Any other error becomes a generic JSON 500 carrying err.Error().
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
	// errors.As also matches a ConversionError that was wrapped with
	// additional context, which a direct type assertion would miss.
	var convErr *conversion.ConversionError
	if errors.As(err, &convErr) {
		// NOTE(review): the third return value of EncodeError is
		// discarded, as before — confirm whether it needs handling.
		body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
		c.Data(statusCode, "application/json", body)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// writeError reports a routing error to the client as a JSON 404 whose
// "error" field is the error text. clientProtocol is currently unused.
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
	payload := gin.H{"error": err.Error()}
	c.JSON(http.StatusNotFound, payload)
}
// forwardPassthrough forwards a request that could not be routed by model
// (for example a GET without a body) to the first provider returned by the
// provider service.
//
// When the client and provider speak the same protocol, the request is sent
// as-is to the provider's base URL with adapter-built headers; otherwise it
// goes through the conversion engine. The upstream response is converted back
// to the client protocol and written out. If no provider is available the
// client receives a 404.
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
	registry := h.engine.GetRegistry()
	adapter, err := registry.Get(clientProtocol)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
		return
	}

	providers, err := h.providerService.List()
	if err != nil || len(providers) == 0 {
		// zap.Error is a no-op field when err is nil, so the lookup
		// failure is logged only when there actually was one.
		h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL), zap.Error(err))
		c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
		return
	}

	// Use the first provider; one without an explicit protocol is treated
	// as OpenAI.
	p := providers[0]
	providerProtocol := p.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	ifaceType := adapter.DetectInterfaceType(inSpec.URL)

	// No model name: there is nothing to rewrite in the request.
	targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")

	var outSpec *conversion.HTTPRequestSpec
	if clientProtocol == providerProtocol {
		// Same protocol: keep path and body, only swap in provider headers.
		upstreamURL := p.BaseURL + inSpec.URL
		headers := adapter.BuildHeaders(targetProvider)
		if _, ok := headers["Content-Type"]; !ok {
			headers["Content-Type"] = "application/json"
		}
		outSpec = &conversion.HTTPRequestSpec{
			URL:     upstreamURL,
			Method:  inSpec.Method,
			Headers: headers,
			Body:    inSpec.Body,
		}
	} else {
		outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
		if err != nil {
			h.writeConversionError(c, err, clientProtocol)
			return
		}
	}

	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	for k, v := range convertedResp.Headers {
		c.Header(k, v)
	}
	// Default the Content-Type only when the converted response did not set
	// one. This must inspect the response headers: c.GetHeader reads the
	// incoming request's headers, and bodiless requests usually carry no
	// Content-Type, which would overwrite the upstream value.
	if c.Writer.Header().Get("Content-Type") == "" {
		c.Header("Content-Type", "application/json")
	}
	c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
}
// extractHeaders flattens the incoming request's headers into a
// single-valued map, keeping only the first value of each header.
func extractHeaders(c *gin.Context) map[string]string {
	flat := map[string]string{}
	for name, values := range c.Request.Header {
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}