feat(ai-web): 完成选择节点
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web.engine;
|
||||
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowGraph;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowNode;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.store.FlowStore;
|
||||
import java.util.LinkedList;
|
||||
import java.util.Queue;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import org.eclipse.collections.api.map.ImmutableMap;
|
||||
|
||||
/**
|
||||
@@ -16,14 +14,13 @@ import org.eclipse.collections.api.map.ImmutableMap;
|
||||
public class FlowExecutor {
|
||||
private final FlowStore flowStore;
|
||||
private final ImmutableMap<String, Class<? extends FlowNodeRunner>> runnerMap;
|
||||
private final Queue<FlowNode> executionQueue = new LinkedList<>();
|
||||
|
||||
public FlowExecutor(FlowStore flowStore, ImmutableMap<String, Class<? extends FlowNodeRunner>> runnerMap) {
|
||||
this.flowStore = flowStore;
|
||||
this.runnerMap = runnerMap;
|
||||
}
|
||||
|
||||
public void execute(FlowGraph graph) {
|
||||
public void execute(FlowGraph graph) throws InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
|
||||
var runner = new FlowGraphRunner(graph, flowStore, runnerMap);
|
||||
runner.run();
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web.engine;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowContext;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowEdge;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowGraph;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowNode;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.store.FlowStore;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.Queue;
|
||||
import lombok.SneakyThrows;
|
||||
import org.eclipse.collections.api.map.ImmutableMap;
|
||||
import org.eclipse.collections.api.multimap.set.ImmutableSetMultimap;
|
||||
|
||||
@@ -17,7 +18,7 @@ import org.eclipse.collections.api.multimap.set.ImmutableSetMultimap;
|
||||
* @author lanyuanxiaoyao
|
||||
* @version 20250701
|
||||
*/
|
||||
public final class FlowGraphRunner implements Runnable {
|
||||
public final class FlowGraphRunner {
|
||||
private final FlowGraph flowGraph;
|
||||
private final FlowStore flowStore;
|
||||
private final ImmutableMap<String, Class<? extends FlowNodeRunner>> nodeRunnerClass;
|
||||
@@ -36,11 +37,8 @@ public final class FlowGraphRunner implements Runnable {
|
||||
nodeMap = flowGraph.nodes().toImmutableMap(FlowNode::id, node -> node);
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void run() {
|
||||
public void run() throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
flowStore.init(flowGraph);
|
||||
flowStore.updateGraphToRunning(flowGraph.id());
|
||||
|
||||
var context = new FlowContext();
|
||||
for (FlowNode node : flowGraph.nodes()) {
|
||||
@@ -48,7 +46,26 @@ public final class FlowGraphRunner implements Runnable {
|
||||
}
|
||||
while (!executionQueue.isEmpty()) {
|
||||
var node = executionQueue.poll();
|
||||
if (readyForRunning(node)) {
|
||||
process(node, context);
|
||||
}
|
||||
}
|
||||
|
||||
private void process(FlowNode node, FlowContext context) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
|
||||
if (
|
||||
(
|
||||
// 没有入节点,即开始节点
|
||||
!nodeInputMap.containsKey(node.id())
|
||||
// 或者所有入的边状态都已经完成
|
||||
|| nodeInputMap.get(node.id()).allSatisfy(edge -> flowStore.checkEdgeStatus(flowGraph.id(), edge.id(), FlowEdge.Status.EXECUTE, FlowEdge.Status.SKIP))
|
||||
)
|
||||
// 当前节点还未执行
|
||||
&& flowStore.checkNodeStatus(flowGraph.id(), node.id(), FlowNode.Status.INITIAL)
|
||||
) {
|
||||
// 是开始节点或入的边有至少一条是「执行」
|
||||
if (
|
||||
!nodeInputMap.containsKey(node.id())
|
||||
|| nodeInputMap.get(node.id()).anySatisfy(edge -> flowStore.checkEdgeStatus(flowGraph.id(), edge.id(), FlowEdge.Status.EXECUTE))
|
||||
) {
|
||||
flowStore.updateNodeToRunning(flowGraph.id(), node.id());
|
||||
|
||||
var runnerClazz = nodeRunnerClass.get(node.type());
|
||||
@@ -57,20 +74,34 @@ public final class FlowGraphRunner implements Runnable {
|
||||
runner.setContext(context);
|
||||
runner.run();
|
||||
|
||||
if (runner instanceof FlowNodeOptionalRunner) {
|
||||
var targetPoint = ((FlowNodeOptionalRunner) runner).getTargetPoint();
|
||||
for (FlowEdge edge : nodeOutputMap.get(node.id())) {
|
||||
if (StrUtil.equals(targetPoint, edge.sourcePoint())) {
|
||||
flowStore.updateEdgeToExecute(flowGraph.id(), edge.id());
|
||||
} else {
|
||||
flowStore.updateEdgeToSkip(flowGraph.id(), edge.id());
|
||||
}
|
||||
executionQueue.offer(nodeMap.get(edge.target()));
|
||||
}
|
||||
} else {
|
||||
for (FlowEdge edge : nodeOutputMap.get(node.id())) {
|
||||
flowStore.updateEdgeToExecute(flowGraph.id(), edge.id());
|
||||
executionQueue.offer(nodeMap.get(edge.target()));
|
||||
}
|
||||
}
|
||||
|
||||
flowStore.updateNodeToFinished(flowGraph.id(), node.id());
|
||||
}
|
||||
// 所有入的边都是跳过,当前节点就跳过
|
||||
else {
|
||||
flowStore.updateNodeToSkipped(flowGraph.id(), node.id());
|
||||
|
||||
for (FlowEdge edge : nodeOutputMap.get(node.id())) {
|
||||
flowStore.updateEdgeToExecute(flowGraph.id(), edge.id());
|
||||
flowStore.updateEdgeToSkip(flowGraph.id(), edge.id());
|
||||
executionQueue.offer(nodeMap.get(edge.target()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flowStore.updateGraphToFinished(flowGraph.id());
|
||||
}
|
||||
|
||||
private boolean readyForRunning(FlowNode node) {
|
||||
return (!nodeInputMap.containsKey(node.id()) || nodeInputMap.get(node.id()).allSatisfy(edge -> flowStore.checkEdgeStatus(flowGraph.id(), edge.id(), FlowEdge.Status.EXECUTE, FlowEdge.Status.SKIP)))
|
||||
&& flowStore.checkNodeStatus(flowGraph.id(), node.id(), FlowNode.Status.INITIAL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web.engine;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
public abstract class FlowNodeOptionalRunner extends FlowNodeRunner {
|
||||
@Getter
|
||||
private String targetPoint;
|
||||
|
||||
public abstract String runOptional();
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.targetPoint = runOptional();
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web.engine.entity;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import org.eclipse.collections.api.set.ImmutableSet;
|
||||
|
||||
/**
|
||||
@@ -14,22 +13,4 @@ public record FlowGraph(
|
||||
ImmutableSet<FlowNode> nodes,
|
||||
ImmutableSet<FlowEdge> edges
|
||||
) {
|
||||
public enum Status {
|
||||
INITIAL, RUNNING, FINISHED, ERROR
|
||||
}
|
||||
|
||||
public record State(
|
||||
String id,
|
||||
Status status,
|
||||
LocalDateTime startingTime,
|
||||
LocalDateTime finishedTime
|
||||
) {
|
||||
public State(String id) {
|
||||
this(id, Status.INITIAL, LocalDateTime.now(), null);
|
||||
}
|
||||
|
||||
public State(String id, Status status) {
|
||||
this(id, status, LocalDateTime.now(), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web.engine.entity;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import org.eclipse.collections.api.set.ImmutableSet;
|
||||
|
||||
/**
|
||||
* 流程图中的节点
|
||||
@@ -11,9 +10,7 @@ import org.eclipse.collections.api.set.ImmutableSet;
|
||||
*/
|
||||
public record FlowNode(
|
||||
String id,
|
||||
String type,
|
||||
ImmutableSet<String> inputPoints,
|
||||
ImmutableSet<String> outputPoints
|
||||
String type
|
||||
) {
|
||||
public enum Status {
|
||||
INITIAL, RUNNING, FINISHED, SKIPPED
|
||||
|
||||
@@ -13,12 +13,6 @@ import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowNode;
|
||||
public interface FlowStore {
|
||||
void init(FlowGraph flowGraph);
|
||||
|
||||
void updateGraphToRunning(String graphId);
|
||||
|
||||
void updateGraphToFinished(String graphId);
|
||||
|
||||
void updateGraphToError(String graphId);
|
||||
|
||||
void updateNodeToRunning(String graphId, String nodeId);
|
||||
|
||||
void updateNodeToSkipped(String graphId, String nodeId);
|
||||
|
||||
@@ -16,7 +16,6 @@ import org.eclipse.collections.api.map.MutableMap;
|
||||
*/
|
||||
@Slf4j
|
||||
public class InMemoryFlowStore implements FlowStore {
|
||||
private static final MutableMap<String, FlowGraph.State> flowGraphStateMap = Maps.mutable.<String, FlowGraph.State>empty().asSynchronized();
|
||||
private static final MutableMap<String, FlowNode.State> flowNodeStateMap = Maps.mutable.<String, FlowNode.State>empty().asSynchronized();
|
||||
private static final MutableMap<String, FlowEdge.State> flowEdgeStateMap = Maps.mutable.<String, FlowEdge.State>empty().asSynchronized();
|
||||
|
||||
@@ -26,7 +25,6 @@ public class InMemoryFlowStore implements FlowStore {
|
||||
|
||||
@Override
|
||||
public void init(FlowGraph flowGraph) {
|
||||
flowGraphStateMap.put(flowGraph.id(), new FlowGraph.State(flowGraph.id()));
|
||||
for (FlowNode node : flowGraph.nodes()) {
|
||||
flowNodeStateMap.put(multiKey(flowGraph.id(), node.id()), new FlowNode.State(node.id()));
|
||||
}
|
||||
@@ -35,33 +33,6 @@ public class InMemoryFlowStore implements FlowStore {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateGraphToRunning(String graphId) {
|
||||
flowGraphStateMap.updateValue(
|
||||
graphId,
|
||||
() -> new FlowGraph.State(graphId, FlowGraph.Status.RUNNING),
|
||||
old -> new FlowGraph.State(graphId, FlowGraph.Status.RUNNING, old.startingTime(), old.finishedTime())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateGraphToFinished(String graphId) {
|
||||
flowGraphStateMap.updateValue(
|
||||
graphId,
|
||||
() -> new FlowGraph.State(graphId, FlowGraph.Status.FINISHED),
|
||||
old -> new FlowGraph.State(graphId, FlowGraph.Status.FINISHED, old.startingTime(), old.finishedTime())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateGraphToError(String graphId) {
|
||||
flowGraphStateMap.updateValue(
|
||||
graphId,
|
||||
() -> new FlowGraph.State(graphId, FlowGraph.Status.ERROR),
|
||||
old -> new FlowGraph.State(graphId, FlowGraph.Status.ERROR, old.startingTime(), old.finishedTime())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateNodeToRunning(String graphId, String nodeId) {
|
||||
flowNodeStateMap.updateValue(
|
||||
@@ -128,8 +99,6 @@ public class InMemoryFlowStore implements FlowStore {
|
||||
@Override
|
||||
public void print() {
|
||||
log.info("====== Flow Store ======");
|
||||
log.info("====== Flow Graph ======");
|
||||
flowGraphStateMap.forEachKeyValue((key, value) -> log.info("{}: {}", key, value.status()));
|
||||
log.info("====== Flow Node ======");
|
||||
flowNodeStateMap.forEachKeyValue((key, value) -> log.info("{}: {}", key, value.status()));
|
||||
log.info("====== Flow Edge ======");
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.lanyuanxiaoyao.service.ai.web;
|
||||
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.FlowExecutor;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.FlowNodeOptionalRunner;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.FlowNodeRunner;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowEdge;
|
||||
import com.lanyuanxiaoyao.service.ai.web.engine.entity.FlowGraph;
|
||||
@@ -22,28 +23,39 @@ public class TestFlow {
|
||||
var executor = new FlowExecutor(
|
||||
store,
|
||||
Maps.immutable.of(
|
||||
"plain-node", PlainNode.class
|
||||
"plain-node", PlainNode.class,
|
||||
"option-node", PlainOptionNode.class
|
||||
)
|
||||
);
|
||||
/*
|
||||
* 4 6 7
|
||||
* 1 2 5 8---3
|
||||
* \9/
|
||||
*/
|
||||
var graph = new FlowGraph(
|
||||
"graph-1",
|
||||
Sets.immutable.of(
|
||||
new FlowNode("node-1", "plain-node", Sets.immutable.empty(), Sets.immutable.of("target")),
|
||||
new FlowNode("node-2", "plain-node", Sets.immutable.of("source"), Sets.immutable.of("target")),
|
||||
new FlowNode("node-4", "plain-node", Sets.immutable.of("source"), Sets.immutable.of("target")),
|
||||
new FlowNode("node-6", "plain-node", Sets.immutable.of("source"), Sets.immutable.of("target")),
|
||||
new FlowNode("node-7", "plain-node", Sets.immutable.of("source"), Sets.immutable.of("target")),
|
||||
new FlowNode("node-5", "plain-node", Sets.immutable.of("source"), Sets.immutable.of("target")),
|
||||
new FlowNode("node-3", "plain-node", Sets.immutable.of("source"), Sets.immutable.empty())
|
||||
new FlowNode("node-1", "plain-node"),
|
||||
new FlowNode("node-2", "plain-node"),
|
||||
new FlowNode("node-4", "plain-node"),
|
||||
new FlowNode("node-6", "plain-node"),
|
||||
new FlowNode("node-7", "plain-node"),
|
||||
new FlowNode("node-5", "plain-node"),
|
||||
new FlowNode("node-8", "option-node"),
|
||||
new FlowNode("node-9", "plain-node"),
|
||||
new FlowNode("node-3", "plain-node")
|
||||
),
|
||||
Sets.immutable.of(
|
||||
new FlowEdge("edge-1", "node-1", "node-2", null, null),
|
||||
new FlowEdge("edge-2", "node-2", "node-4", null, null),
|
||||
new FlowEdge("edge-3", "node-2", "node-5", null, null),
|
||||
new FlowEdge("edge-4", "node-4", "node-6", null, null),
|
||||
new FlowEdge("edge-5", "node-6", "node-7", null, null),
|
||||
new FlowEdge("edge-6", "node-7", "node-3", null, null),
|
||||
new FlowEdge("edge-7", "node-5", "node-3", null, null)
|
||||
new FlowEdge("edge-4", "node-5", "node-8", null, null),
|
||||
new FlowEdge("edge-5", "node-8", "node-9", "yes", null),
|
||||
new FlowEdge("edge-6", "node-8", "node-3", "no", null),
|
||||
new FlowEdge("edge-7", "node-9", "node-3", null, null),
|
||||
new FlowEdge("edge-8", "node-4", "node-6", null, null),
|
||||
new FlowEdge("edge-9", "node-6", "node-7", null, null),
|
||||
new FlowEdge("edge-10", "node-7", "node-3", null, null)
|
||||
)
|
||||
);
|
||||
executor.execute(graph);
|
||||
@@ -56,4 +68,13 @@ public class TestFlow {
|
||||
log.info("run node id: {}", getNodeId());
|
||||
}
|
||||
}
|
||||
|
||||
public static class PlainOptionNode extends FlowNodeOptionalRunner {
|
||||
@Override
|
||||
public String runOptional() {
|
||||
log.info("run node id: {}", getNodeId());
|
||||
// yes / no
|
||||
return "no";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user