From 111ca49815b0f77879279ee294afbd381c2771f4 Mon Sep 17 00:00:00 2001 From: lanyuanxiaoyao Date: Mon, 30 Jun 2025 00:12:51 +0800 Subject: [PATCH] =?UTF-8?q?feat(web):=20=E5=BC=95=E5=85=A5dify=E7=9A=84?= =?UTF-8?q?=E6=B5=81=E7=A8=8B=E6=A3=80=E6=9F=A5=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../client/src/pages/ai/flow/FlowChecker.tsx | 189 +++++++++++++++++- 1 file changed, 179 insertions(+), 10 deletions(-) diff --git a/service-web/client/src/pages/ai/flow/FlowChecker.tsx b/service-web/client/src/pages/ai/flow/FlowChecker.tsx index 37f7862..1c6192f 100644 --- a/service-web/client/src/pages/ai/flow/FlowChecker.tsx +++ b/service-web/client/src/pages/ai/flow/FlowChecker.tsx @@ -1,5 +1,5 @@ import {find, findIdx, isEqual, lpad, toStr} from 'licia' -import {type Connection, type Edge, getOutgoers, type Node} from '@xyflow/react' +import {type Connection, type Edge, getConnectedEdges, getIncomers, getOutgoers, type Node} from '@xyflow/react' export class CheckError extends Error { readonly id: string @@ -40,6 +40,183 @@ export const hasCycleError = () => new CheckError(204, '禁止流程循环') export const nodeNotOnlyToEndNode = () => new CheckError(206, '直连结束节点的节点不允许连接其他节点') export const hasRedundantEdgeError = () => new CheckError(207, '禁止出现冗余边') +const hasCycle = (sourceNode: Node, targetNode: Node, nodes: Node[], edges: Edge[], visited = new Set()) => { + if (visited.has(targetNode.id)) return false + visited.add(targetNode.id) + for (const outgoer of getOutgoers(targetNode, nodes, edges)) { + if (isEqual(outgoer.id, sourceNode.id)) return true + if (hasCycle(sourceNode, outgoer, nodes, edges, visited)) return true + } +} + +type ParallelInfoItem = { + parallelNodeId: string + depth: number + isBranch?: boolean +} +type NodeParallelInfo = { + parallelNodeId: string + edgeHandleId: string + depth: number +} +type NodeHandle = { + node: Node + handle: string +} +type NodeStreamInfo = { + upstreamNodes: Set + downstreamEdges: Set +} + +const getParallelInfo = (nodes: Node[], edges: Edge[]) => { + let startNode + + startNode = nodes.find(node => isEqual(node.type, 'start-node')) + if (!startNode) + throw new Error('Start node not found') + + const parallelList = [] as ParallelInfoItem[] + const nextNodeHandles = [{node: startNode, handle: 'source'}] + let hasAbnormalEdges = false + + const groupBy = (array: Record[], iteratee: string) => { + const result: Record = {} + for (const item of array) { + // 获取属性值并转换为字符串键 + const key = item[iteratee] + if (!result[key]) { + result[key] = [] + } + result[key].push(item) + } + return result + } + + const traverse = (firstNodeHandle: NodeHandle) => { + const nodeEdgesSet = {} as Record> + const totalEdgesSet = new Set() + const nextHandles = [firstNodeHandle] + const streamInfo = {} as Record + const parallelListItem = { + parallelNodeId: '', + depth: 0, + } as ParallelInfoItem + const nodeParallelInfoMap = {} as Record + nodeParallelInfoMap[firstNodeHandle.node.id] = { + parallelNodeId: '', + edgeHandleId: '', + depth: 0, + } + + while (nextHandles.length) { + const currentNodeHandle = nextHandles.shift()! + const {node: currentNode, handle: currentHandle = 'source'} = currentNodeHandle + const currentNodeHandleKey = currentNode.id + const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle) + const connectedEdgesLength = connectedEdges.length + const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id)) + const incomers = getIncomers(currentNode, nodes, edges) + + if (!streamInfo[currentNodeHandleKey]) { + streamInfo[currentNodeHandleKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) { + const newSet = new Set() + for (const item of totalEdgesSet) { + if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item)) + newSet.add(item) + } + if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) { + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + nextNodeHandles.push({node: currentNode, handle: currentHandle}) + break + } + } + + if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth) + parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth + + outgoers.forEach((outgoer) => { + const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id) + const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle') + const incomers = getIncomers(outgoer, nodes, edges) + + if (outgoers.length > 1 && incomers.length > 1) + hasAbnormalEdges = true + + Object.keys(sourceEdgesGroup).forEach((sourceHandle) => { + nextHandles.push({node: outgoer, handle: sourceHandle}) + }) + if (!outgoerConnectedEdges.length) + nextHandles.push({node: outgoer, handle: 'source'}) + + const outgoerKey = outgoer.id + if (!nodeEdgesSet[outgoerKey]) + nodeEdgesSet[outgoerKey] = new Set() + + if (nodeEdgesSet[currentNodeHandleKey]) { + for (const item of nodeEdgesSet[currentNodeHandleKey]) + nodeEdgesSet[outgoerKey].add(item) + } + + if (!streamInfo[outgoerKey]) { + streamInfo[outgoerKey] = { + upstreamNodes: new Set(), + downstreamEdges: new Set(), + } + } + + if (!nodeParallelInfoMap[outgoer.id]) { + nodeParallelInfoMap[outgoer.id] = { + ...nodeParallelInfoMap[currentNode.id], + } + } + + if (connectedEdgesLength > 1) { + const edge = connectedEdges.find(edge => edge.target === outgoer.id)! + nodeEdgesSet[outgoerKey].add(edge.id) + totalEdgesSet.add(edge.id) + + streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id) + streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey) + + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[item].downstreamEdges.add(edge.id) + + if (!parallelListItem.parallelNodeId) + parallelListItem.parallelNodeId = currentNode.id + + const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1 + const currentDepth = nodeParallelInfoMap[outgoer.id].depth + + nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth) + } else { + for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) + streamInfo[outgoerKey].upstreamNodes.add(item) + + nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth + } + }) + } + + parallelList.push(parallelListItem) + } + + while (nextNodeHandles.length) { + const nodeHandle = nextNodeHandles.shift()! + traverse(nodeHandle) + } + + return { + parallelList, + hasAbnormalEdges, + } +} + export const checkAddConnection: (connection: Connection, nodes: Node[], edges: Edge[]) => void = (connection, nodes, edges) => { let sourceNode = getNodeById(connection.source, nodes) if (!sourceNode) { @@ -55,17 +232,9 @@ export const checkAddConnection: (connection: Connection, nodes: Node[], edges: } // 禁止流程出现环,必须是有向无环图 - const hasCycle = (node: Node, visited = new Set()) => { - if (visited.has(node.id)) return false - visited.add(node.id) - for (const outgoer of getOutgoers(node, nodes, edges)) { - if (isEqual(outgoer.id, sourceNode?.id)) return true - if (hasCycle(outgoer, visited)) return true - } - } if (isEqual(sourceNode.id, targetNode.id)) { throw nodeToSelfError() - } else if (hasCycle(targetNode)) { + } else if (hasCycle(sourceNode, targetNode, nodes, edges)) { throw hasCycleError() }