in spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/CompiledGraph.java [110:192]
/**
 * Compiles the given {@link StateGraph} into an executable graph.
 * <p>
 * The constructor processes the graph's nodes, edges and configuration, checks that
 * every configured interruption refers to an existing node, instantiates each node
 * action, and turns every fan-out edge (an edge with more than one target) into a
 * synthetic {@link ParallelNode} whose branches must all continue to one single node.
 * @param stateGraph the graph definition to compile
 * @param compileConfig the compilation settings supplied by the caller
 * @throws GraphStateException if an interruption refers to an unknown node, if a
 * parallel branch uses a conditional edge, or if parallel branches lead to more than
 * one target node
 * @throws NullPointerException if a node has no action factory
 * @throws IllegalStateException if none of the parallel branches of a fan-out edge
 * declares an outgoing edge
 */
protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
	this.stateGraph = stateGraph;
	// A configured state factory takes precedence over the graph's own state instance.
	this.overAllState = Objects.nonNull(stateGraph.getOverAllStateFactory())
			? stateGraph.getOverAllStateFactory().create() : stateGraph.getOverAllState();
	this.processedData = ProcessedNodesEdgesAndConfig.process(stateGraph, compileConfig);

	// CHECK INTERRUPTIONS
	for (String interruption : processedData.interruptsBefore()) {
		if (!processedData.nodes().anyMatchById(interruption)) {
			throw StateGraph.Errors.interruptionNodeNotExist.exception(interruption);
		}
	}
	for (String interruption : processedData.interruptsAfter()) {
		if (!processedData.nodes().anyMatchById(interruption)) {
			throw StateGraph.Errors.interruptionNodeNotExist.exception(interruption);
		}
	}

	// RE-CREATE THE EVENTUALLY UPDATED COMPILE CONFIG
	this.compileConfig = CompileConfig.builder(compileConfig)
		.interruptsBefore(processedData.interruptsBefore())
		.interruptsAfter(processedData.interruptsAfter())
		.build();

	// EVALUATE NODES
	for (var n : processedData.nodes().elements) {
		var factory = n.actionFactory();
		// Supplier overload: the message is only formatted when the check fails.
		Objects.requireNonNull(factory, () -> format("action factory for node id '%s' is null!", n.id()));
		// NOTE(review): factories receive the caller-supplied config (the parameter), not
		// the rebuilt this.compileConfig; confirm this is intended.
		nodes.put(n.id(), factory.apply(compileConfig));
	}

	// EVALUATE EDGES
	var allEdges = processedData.edges().elements;
	for (var e : allEdges) {
		var targets = e.targets();
		if (targets.size() == 1) {
			edges.put(e.sourceId(), targets.get(0));
			continue;
		}

		// Fan-out edge: only targets registered as nodes become parallel branches.
		var parallelTargets = targets.stream().filter(target -> nodes.containsKey(target.id())).toList();

		// Outgoing edge of each parallel branch; branches without one are skipped.
		// Edge lookup relies on Edge equality by source id (as the original lookup did).
		var parallelNodeEdges = parallelTargets.stream()
			.map(target -> allEdges.indexOf(new Edge(target.id())))
			.filter(index -> index >= 0)
			.map(allEdges::get)
			.toList();

		// Conditional edges have no fixed target id, so they are rejected regardless of
		// how many distinct targets the branches have; otherwise the parallel node would
		// be wired to an edge with a null target id.
		var conditionalEdges = parallelNodeEdges.stream().filter(ee -> ee.target().value() != null).toList();
		if (!conditionalEdges.isEmpty()) {
			throw StateGraph.Errors.unsupportedConditionalEdgeOnParallelNode.exception(e.sourceId(),
					conditionalEdges.stream().map(Edge::sourceId).toList());
		}

		// All branches must converge on exactly one next node.
		var parallelNodeTargets = parallelNodeEdges.stream()
			.map(ee -> ee.target().id())
			.collect(Collectors.toSet());
		if (parallelNodeTargets.size() > 1) {
			throw StateGraph.Errors.illegalMultipleTargetsOnParallelNode.exception(e.sourceId(),
					parallelNodeTargets);
		}
		if (parallelNodeTargets.isEmpty()) {
			throw new IllegalStateException(
					format("none of the parallel branches of node '%s' declares an outgoing edge!", e.sourceId()));
		}

		var actions = parallelTargets.stream().map(target -> nodes.get(target.id())).toList();
		var parallelNode = new ParallelNode(e.sourceId(), actions, stateGraph.keyStrategies());
		nodes.put(parallelNode.id(), parallelNode.actionFactory().apply(compileConfig));
		// Route source -> parallel node -> the single node all branches converge on.
		edges.put(e.sourceId(), new EdgeValue(parallelNode.id()));
		edges.put(parallelNode.id(), new EdgeValue(parallelNodeTargets.iterator().next()));
	}
}