import {type Connection, type Edge, getConnectedEdges, getIncomers, getOutgoers, type Node} from '@xyflow/react'
import {clone, find, findIdx, isEqual, lpad, toStr, uuid} from 'licia'

/**
 * Validation error raised by the flow checks in this module.
 *
 * The numeric code passed to the constructor is stored in `id` as the
 * letter `E` followed by the code padded with `'0'` to 6 characters
 * (via licia `lpad`).
 */
export class CheckError extends Error {
    readonly id: string

    constructor(
        id: number,
        message: string,
    ) {
        super(message)
        this.id = `E${lpad(toStr(id), 6, '0')}`
    }

    /** Renders the error as `<id>: <message>`. */
    public toString(): string {
        return `${this.id}: ${this.message}`
    }
}

// 1xx: errors raised when adding a node.
export const multiStartNodeError = () => new CheckError(100, '只能存在1个开始节点')
export const multiEndNodeError = () => new CheckError(101, '只能存在1个结束节点')

/** Returns the first node whose `id` equals `id`, or `undefined` when none matches. */
const getNodeById = (id: string, nodes: Node[]) => find(nodes, (n: Node) => isEqual(n.id, id))

/**
 * Validates that a node of the given type may be added.
 *
 * @param type  Node type about to be added.
 * @param nodes Nodes currently in the flow.
 * @param edges Edges currently in the flow (not read by this check).
 * @throws {CheckError} E000100 when adding a `start-node` while one already exists,
 *                      E000101 when adding an `end-node` while one already exists.
 */
// @ts-ignore
export const checkAddNode: (type: string, nodes: Node[], edges: Edge[]) => void = (type, nodes, edges) => {
    if (isEqual(type, 'start-node') && findIdx(nodes, (node: Node) => isEqual(type, node.type)) > -1) {
        throw multiStartNodeError()
    }
    if (isEqual(type, 'end-node') && findIdx(nodes, (node: Node) => isEqual(type, node.type)) > -1) {
        throw multiEndNodeError()
    }
}

// 2xx: errors raised when adding a connection.
export const sourceNodeNotFoundError = () => new CheckError(200, '连线起始节点未找到')
export const targetNodeNotFoundError = () => new CheckError(201, '连线目标节点未找到')
export const startNodeToEndNodeError = () => new CheckError(202, '开始节点不能直连结束节点')
export const nodeToSelfError = () => new CheckError(203, '节点不能直连自身')
export const hasCycleError = () => new CheckError(204, '禁止流程循环')
export const nodeNotOnlyToEndNode = () => new CheckError(206, '直连结束节点的节点不允许连接其他节点')
export const hasRedundantEdgeError = () => new CheckError(207, '禁止出现冗余边')

/**
 * Reports whether `sourceNode` can be reached from `targetNode` by following
 * existing outgoing edges, i.e. whether a new edge `sourceNode -> targetNode`
 * would close a cycle.
 *
 * @param sourceNode Source of the prospective edge.
 * @param targetNode Target of the prospective edge; the search starts here.
 * @param nodes      All nodes of the flow.
 * @param edges      All existing edges of the flow.
 * @param visited    Ids of nodes already expanded; shared across the recursion.
 * @returns `true` when `sourceNode` is reachable from `targetNode`, otherwise `false`.
 */
const hasCycle = (
    sourceNode: Node,
    targetNode: Node,
    nodes: Node[],
    edges: Edge[],
    visited = new Set<string>(),
): boolean => {
    // Every node is expanded at most once.
    if (visited.has(targetNode.id)) return false
    visited.add(targetNode.id)

    for (const outgoer of getOutgoers(targetNode, nodes, edges)) {
        if (isEqual(outgoer.id, sourceNode.id)) return true
        if (hasCycle(sourceNode, outgoer, nodes, edges, visited)) return true
    }
    // Always return a boolean instead of falling through with `undefined`.
    return false
}

/* Flow validity check adapted from Dify */
type ParallelInfoItem = {
    parallelNodeId: string
    depth: number
    isBranch?: boolean
}
type NodeParallelInfo = {
    parallelNodeId: string
    edgeHandleId: string
    depth: number
}
type NodeHandle = {
    node: Node
    handle: string
}
type NodeStreamInfo = {
    upstreamNodes: Set<string>
    downstreamEdges: Set<string>
}

/**
 * Groups `array` by the value of the property named `iteratee`.
 *
 * The property value is used directly as an object key, so non-string values
 * (including `undefined` / `null`) end up under their string form.
 */
const groupBy = (array: Record<string, any>[], iteratee: string) => {
    const result: Record<string, any[]> = {}
    for (const item of array) {
        // Read the property value; it becomes the (string) key of the group.
        const key = item[iteratee]
        if (!result[key]) {
            result[key] = []
        }
        result[key].push(item)
    }
    return result
}

/**
 * Traverses the flow from its `start-node` and collects parallel-branch
 * information.
 *
 * `hasAbnormalEdges` is set when a node that has more than one outgoer on the
 * handle being expanded leads to an outgoer that itself has more than one
 * incomer.
 *
 * NOTE(review): the traversal keeps no visited set, so it relies on the graph
 * being acyclic (checkAddConnection rejects cycles before calling this).
 *
 * @param nodes        All nodes of the flow.
 * @param edges        All edges of the flow.
 * @param parentNodeId Currently unused; reserved for sub-graph support.
 * @returns The list of parallel items (one per traversal pass) and the
 *          `hasAbnormalEdges` flag.
 * @throws {Error} When no node of type `start-node` exists.
 */
// @ts-ignore
export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
    // Revisit once sub-graphs are supported.
    /*if (parentNodeId) {
        const parentNode = nodes.find(node => node.id === parentNodeId)
        if (!parentNode) throw new Error('Parent node not found')
        startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id)
    } else {
        startNode = nodes.find(node => isEqual(node.type, 'start_node'))
    }*/
    const startNode = nodes.find(node => isEqual(node.type, 'start-node'))
    if (!startNode) throw new Error('Start node not found')

    const parallelList = [] as ParallelInfoItem[]
    const nextNodeHandles = [{node: startNode, handle: 'source'}]
    let hasAbnormalEdges = false

    const traverse = (firstNodeHandle: NodeHandle) => {
        // Per node: ids of the fan-out edges on the paths that reached it.
        const nodeEdgesSet = {} as Record<string, Set<string>>
        // Ids of every fan-out edge seen during this pass.
        const totalEdgesSet = new Set<string>()
        const nextHandles = [firstNodeHandle]
        const streamInfo = {} as Record<string, NodeStreamInfo>
        const parallelListItem = {
            parallelNodeId: '',
            depth: 0,
        } as ParallelInfoItem
        const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
        nodeParallelInfoMap[firstNodeHandle.node.id] = {
            parallelNodeId: '',
            edgeHandleId: '',
            depth: 0,
        }

        while (nextHandles.length) {
            const currentNodeHandle = nextHandles.shift()!
            const {node: currentNode, handle: currentHandle = 'source'} = currentNodeHandle
            const currentNodeHandleKey = currentNode.id
            // NOTE(review): strict equality on sourceHandle — an edge whose sourceHandle is
            // null/undefined never matches the string handles queued below; confirm every
            // edge carries an explicit sourceHandle.
            const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
            const connectedEdgesLength = connectedEdges.length
            const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
            const incomers = getIncomers(currentNode, nodes, edges)

            if (!streamInfo[currentNodeHandleKey]) {
                streamInfo[currentNodeHandleKey] = {
                    upstreamNodes: new Set<string>(),
                    downstreamEdges: new Set<string>(),
                }
            }

            if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
                // Fan-out edges of this pass that are not downstream of the current node.
                const newSet = new Set<string>()
                for (const item of totalEdgesSet) {
                    if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item)) newSet.add(item)
                }
                // NOTE(review): relies on licia isEqual comparing Set contents — confirm.
                if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
                    // Stop this pass here and queue the node as the start of a new pass.
                    parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
                    nextNodeHandles.push({node: currentNode, handle: currentHandle})
                    break
                }
            }

            if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth) parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth

            outgoers.forEach((outgoer) => {
                const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
                const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
                const incomers = getIncomers(outgoer, nodes, edges)

                if (outgoers.length > 1 && incomers.length > 1) hasAbnormalEdges = true

                // Queue the outgoer once per source handle it has outgoing edges on ...
                Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
                    nextHandles.push({node: outgoer, handle: sourceHandle})
                })
                // ... or once with the default handle when it has no outgoing edges.
                if (!outgoerConnectedEdges.length) nextHandles.push({node: outgoer, handle: 'source'})

                const outgoerKey = outgoer.id
                if (!nodeEdgesSet[outgoerKey]) nodeEdgesSet[outgoerKey] = new Set<string>()

                if (nodeEdgesSet[currentNodeHandleKey]) {
                    for (const item of nodeEdgesSet[currentNodeHandleKey]) nodeEdgesSet[outgoerKey].add(item)
                }

                if (!streamInfo[outgoerKey]) {
                    streamInfo[outgoerKey] = {
                        upstreamNodes: new Set<string>(),
                        downstreamEdges: new Set<string>(),
                    }
                }

                if (!nodeParallelInfoMap[outgoer.id]) {
                    nodeParallelInfoMap[outgoer.id] = {
                        ...nodeParallelInfoMap[currentNode.id],
                    }
                }

                if (connectedEdgesLength > 1) {
                    // The current handle fans out: record the edge and deepen the outgoer.
                    const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
                    nodeEdgesSet[outgoerKey].add(edge.id)
                    totalEdgesSet.add(edge.id)

                    streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
                    streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)

                    for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) streamInfo[item].downstreamEdges.add(edge.id)

                    if (!parallelListItem.parallelNodeId) parallelListItem.parallelNodeId = currentNode.id

                    const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
                    const currentDepth = nodeParallelInfoMap[outgoer.id].depth

                    nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
                } else {
                    // Single edge: the outgoer inherits the upstream nodes and the depth.
                    for (const item of streamInfo[currentNodeHandleKey].upstreamNodes) streamInfo[outgoerKey].upstreamNodes.add(item)

                    nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
                }
            })
        }

        parallelList.push(parallelListItem)
    }

    while (nextNodeHandles.length) {
        const nodeHandle = nextNodeHandles.shift()!
        traverse(nodeHandle)
    }

    return {
        parallelList,
        hasAbnormalEdges,
    }
}

/**
 * Validates that `connection` may be added to the flow.
 *
 * @param connection Connection about to be created.
 * @param nodes      Nodes currently in the flow.
 * @param edges      Edges currently in the flow (not mutated).
 * @throws {CheckError} E000200 / E000201 when the source / target node is missing,
 *                      E000202 when a start node is wired directly to an end node,
 *                      E000203 when a node is wired to itself,
 *                      E000204 when the connection would create a cycle,
 *                      E000207 when getParallelInfo reports abnormal edges.
 */
export const checkAddConnection: (connection: Connection, nodes: Node[], edges: Edge[]) => void = (connection, nodes, edges) => {
    const sourceNode = getNodeById(connection.source, nodes)
    if (!sourceNode) {
        throw sourceNodeNotFoundError()
    }
    const targetNode = getNodeById(connection.target, nodes)
    if (!targetNode) {
        throw targetNodeNotFoundError()
    }

    // Disallow short-circuiting the whole flow.
    if (isEqual('start-node', sourceNode.type) && isEqual('end-node', targetNode.type)) {
        throw startNodeToEndNodeError()
    }

    // Disallow cycles: the flow must stay a directed acyclic graph.
    if (isEqual(sourceNode.id, targetNode.id)) {
        throw nodeToSelfError()
    } else if (hasCycle(sourceNode, targetNode, nodes, edges)) {
        throw hasCycleError()
    }

    // Run the parallel check on a copy of the edges that includes the new connection.
    const newEdges = [...clone(edges), {...connection, id: uuid()}]
    const {hasAbnormalEdges} = getParallelInfo(nodes, newEdges)
    if (hasAbnormalEdges) {
        throw hasRedundantEdgeError()
    }
}

// 3xx: errors raised when saving.
export const atLeastOneStartNodeError = () => new CheckError(300, '至少存在1个开始节点')
export const atLeastOneEndNodeError = () => new CheckError(301, '至少存在1个结束节点')

/**
 * Validates the flow before saving.
 *
 * @param nodes Nodes of the flow.
 * @param edges Edges of the flow (not read by this check).
 * @param data  Additional payload (not read by this check).
 * @throws {CheckError} E000300 when there is no `start-node`,
 *                      E000301 when there is no `end-node`.
 */
// @ts-ignore
export const checkSave: (nodes: Node[], edges: Edge[], data: any) => void = (nodes, edges, data) => {
    if (nodes.filter(n => isEqual('start-node', n.type)).length < 1) {
        throw atLeastOneStartNodeError()
    }
    if (nodes.filter(n => isEqual('end-node', n.type)).length < 1) {
        throw atLeastOneEndNodeError()
    }
}