in spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/CompiledGraph.java [838:965]
/**
 * Nodes, edges and interruption settings of a {@link StateGraph} after every
 * sub-state-graph node has been inlined into the parent graph.
 *
 * @param nodes the nodes of the (possibly flattened) graph
 * @param edges the edges of the (possibly flattened) graph
 * @param interruptsBefore ids of the nodes to interrupt before; a subgraph id is replaced by
 * the formatted id of the subgraph's first node
 * @param interruptsAfter ids of the nodes to interrupt after
 */
record ProcessedNodesEdgesAndConfig(StateGraph.Nodes nodes, StateGraph.Edges edges, Set<String> interruptsBefore,
		Set<String> interruptsAfter) {

	/**
	 * Creates an instance that takes the nodes and edges of the given graph and the
	 * interruption settings of the given config as-is, without any subgraph processing.
	 * @param stateGraph the state graph
	 * @param config the compile config
	 */
	ProcessedNodesEdgesAndConfig(StateGraph stateGraph, CompileConfig config) {
		this(stateGraph.nodes, stateGraph.edges, config.interruptsBefore(), config.interruptsAfter());
	}

	/**
	 * Flattens the given graph by inlining each sub-state-graph node.
	 * <p>
	 * For every subgraph node, the method does the following:
	 * <ul>
	 * <li>It adds the subgraph's inner nodes and edges, with ids rewritten through
	 * {@code subgraphNode.formatId}.</li>
	 * <li>It redirects edges that targeted the subgraph node to the (formatted) target of the
	 * subgraph's {@code START} edge.</li>
	 * <li>It connects subgraph edges that reached {@code END} to the target of the subgraph
	 * node's outgoing edge in the parent graph.</li>
	 * <li>It moves an "interrupt before" on the subgraph node to the subgraph's first node.</li>
	 * </ul>
	 * <p>
	 * Incoming and outgoing edges of a subgraph node are looked up in the working edge set,
	 * not in the original graph. This way, edges already rewritten while inlining a
	 * previously processed subgraph are taken into account. That is required for directly
	 * connected subgraphs ({@code sg1 -> sg2}) and for self-loops on a subgraph node.
	 * @param stateGraph the state graph to process
	 * @param config the compile config providing the interruption settings
	 * @return the processed nodes, edges and interruption settings. If the graph has no
	 * subgraph nodes, the result holds the graph's own nodes and edges and the config's
	 * interruption settings unchanged.
	 * @throws GraphStateException in any of these cases:
	 * <ul>
	 * <li>a subgraph has no edge from {@code START}, starts with parallel branches, or its
	 * {@code START} edge has a null target;</li>
	 * <li>a subgraph node is not the target, or not the source, of any edge in the graph;</li>
	 * <li>a subgraph node routes to parallel branches;</li>
	 * <li>an "interruption after" is configured on a subgraph node.</li>
	 * </ul>
	 */
	static ProcessedNodesEdgesAndConfig process(StateGraph stateGraph, CompileConfig config)
			throws GraphStateException {
		var subgraphNodes = stateGraph.nodes.onlySubStateGraphNodes();
		if (subgraphNodes.isEmpty()) {
			return new ProcessedNodesEdgesAndConfig(stateGraph, config);
		}
		var interruptsBefore = config.interruptsBefore();
		var interruptsAfter = config.interruptsAfter();
		var nodes = new StateGraph.Nodes(stateGraph.nodes.exceptSubStateGraphNodes());
		var edges = new StateGraph.Edges(stateGraph.edges.elements);
		for (var subgraphNode : subgraphNodes) {
			var sgWorkflow = subgraphNode.subGraph();
			//
			// Process START Node
			//
			var sgEdgeStart = sgWorkflow.edges.edgeBySourceId(START)
				.orElseThrow(() -> new GraphStateException(
						format("the subgraph of node '%s' has no edge from START!", subgraphNode.id())));
			if (sgEdgeStart.isParallel()) {
				throw new GraphStateException("subgraph not support start with parallel branches yet!");
			}
			var sgEdgeStartTarget = sgEdgeStart.target();
			if (sgEdgeStartTarget.id() == null) {
				throw new GraphStateException(format("the target for node '%s' is null!", subgraphNode.id()));
			}
			var sgEdgeStartRealTargetId = subgraphNode.formatId(sgEdgeStartTarget.id());
			// Process Interruption (Before) Subgraph(s): interrupt before the subgraph's first node
			interruptsBefore = interruptsBefore.stream()
				.map(interrupt -> Objects.equals(subgraphNode.id(), interrupt) ? sgEdgeStartRealTargetId : interrupt)
				.collect(Collectors.toUnmodifiableSet());
			// Look up in the working edge set (not stateGraph.edges): edges pointing at this
			// subgraph may have been created while inlining a previously processed subgraph.
			var edgesWithSubgraphTargetId = edges.edgesByTargetId(subgraphNode.id());
			if (edgesWithSubgraphTargetId.isEmpty()) {
				throw new GraphStateException(
						format("the node '%s' is not present as target in graph!", subgraphNode.id()));
			}
			for (var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId) {
				var newEdge = edgeWithSubgraphTargetId.withSourceAndTargetIdsUpdated(subgraphNode, Function.identity(),
						id -> new EdgeValue(Objects.equals(id, subgraphNode.id()) ? sgEdgeStartRealTargetId : id));
				edges.elements.remove(edgeWithSubgraphTargetId);
				edges.elements.add(newEdge);
			}
			//
			// Process END Nodes
			//
			var sgEdgesEnd = sgWorkflow.edges.edgesByTargetId(END);
			// Same reasoning as above: the outgoing edge may already point at an inlined subgraph.
			var edgeWithSubgraphSourceId = edges.edgeBySourceId(subgraphNode.id())
				.orElseThrow(() -> new GraphStateException(
						format("the node '%s' is not present as source in graph!", subgraphNode.id())));
			if (edgeWithSubgraphSourceId.isParallel()) {
				throw new GraphStateException("subgraph not support routes to parallel branches yet!");
			}
			// Process Interruption (After) Subgraph(s)
			if (interruptsAfter.contains(subgraphNode.id())) {
				var exceptionMessage = (edgeWithSubgraphSourceId.target()
					.id() == null) ? "'interruption after' on subgraph is not supported yet!" : format(
							"'interruption after' on subgraph is not supported yet! consider to use 'interruption before' node: '%s'",
							edgeWithSubgraphSourceId.target().id());
				throw new GraphStateException(exceptionMessage);
			}
			// Subgraph edges reaching END now continue to wherever the subgraph node led.
			sgEdgesEnd.stream()
				.map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
						id -> (Objects.equals(id, END) ? edgeWithSubgraphSourceId.target()
								: new EdgeValue(subgraphNode.formatId(id)))))
				.forEach(edges.elements::add);
			edges.elements.remove(edgeWithSubgraphSourceId);
			//
			// Process edges: copy the remaining inner edges with formatted ids
			//
			sgWorkflow.edges.elements.stream()
				.filter(e -> !Objects.equals(e.sourceId(), START))
				.filter(e -> !e.anyMatchByTargetId(END))
				.map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
						id -> new EdgeValue(subgraphNode.formatId(id))))
				.forEach(edges.elements::add);
			//
			// Process nodes: copy the inner nodes with formatted ids
			//
			sgWorkflow.nodes.elements.stream()
				.map(n -> n.withIdUpdated(subgraphNode::formatId))
				.forEach(nodes.elements::add);
		}
		return new ProcessedNodesEdgesAndConfig(nodes, edges, interruptsBefore, interruptsAfter);
	}

}