in interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/glogue/calcite/handler/GraphRowCountHandler.java [64:177]
/**
 * Estimates the number of rows produced by {@code node}.
 *
 * <p>Graph-specific nodes are handled here; the standard relational operators (Join, Union,
 * Filter, Aggregate, Sort, Project, CommonTableScan) delegate to {@code mdRowCount}. A {@link
 * GraphPattern} is estimated by trying, in order: the glogue count estimator, a
 * {@link GraphExtendIntersect} / {@link GraphJoinDecomposition} found in the pattern's Volcano
 * subset, and finally a formula over per-edge and per-vertex counts.
 *
 * @param node the relational node to estimate
 * @param mq the metadata query used for recursive estimation
 * @return the estimated row count
 * @throws IllegalArgumentException if no estimate can be produced for {@code node}, including a
 *     {@code MultiJoin} whose cluster has no cached cost
 */
public Double getRowCount(RelNode node, RelMetadataQuery mq) {
    if (node instanceof GraphPattern) {
        return estimatePatternRowCount((GraphPattern) node, mq);
    } else if (node instanceof RelSubset) {
        return mq.getRowCount(((RelSubset) node).getOriginal());
    } else if (node instanceof GraphExtendIntersect || node instanceof GraphJoinDecomposition) {
        RelSubset subset = lookupVolcanoSubset(node);
        if (subset != null) {
            // use the row count of the current pattern to estimate the communication cost
            return mq.getRowCount(subset);
        }
        // not registered in a Volcano subset: estimate the pattern this node produces
        Pattern original =
                (node instanceof GraphExtendIntersect)
                        ? ((GraphExtendIntersect) node).getGlogueEdge().getDstPattern()
                        : ((GraphJoinDecomposition) node).getParentPatten();
        return mq.getRowCount(new GraphPattern(node.getCluster(), node.getTraitSet(), original));
    } else if (node instanceof AbstractBindableTableScan) {
        return getRowCount((AbstractBindableTableScan) node, mq);
    } else if (node instanceof GraphLogicalPathExpand) {
        return node.estimateRowCount(mq);
    } else if (node instanceof GraphPhysicalExpand) {
        return node.estimateRowCount(mq);
    } else if (node instanceof GraphPhysicalGetV) {
        return node.estimateRowCount(mq);
    } else if (node instanceof MultiJoin) {
        GraphOptCluster optCluster = (GraphOptCluster) node.getCluster();
        RelOptCost cachedCost = optCluster.getLocalState().getCachedCost();
        if (cachedCost != null) {
            return cachedCost.getRows();
        }
        // no cached cost for a MultiJoin: deliberately falls through to the exception below
    } else if (node instanceof Join) {
        GraphOptCluster optCluster = (GraphOptCluster) node.getCluster();
        RelOptCost cachedCost = optCluster.getLocalState().getCachedCost();
        return cachedCost != null ? cachedCost.getRows() : mdRowCount.getRowCount((Join) node, mq);
    } else if (node instanceof Union) {
        return mdRowCount.getRowCount((Union) node, mq);
    } else if (node instanceof Filter) {
        return mdRowCount.getRowCount((Filter) node, mq);
    } else if (node instanceof Aggregate) {
        return mdRowCount.getRowCount((Aggregate) node, mq);
    } else if (node instanceof Sort) {
        return mdRowCount.getRowCount((Sort) node, mq);
    } else if (node instanceof Project) {
        return mdRowCount.getRowCount((Project) node, mq);
    } else if (node instanceof CommonTableScan) {
        return mdRowCount.getRowCount((CommonTableScan) node, mq);
    }
    throw new IllegalArgumentException("can not estimate row count for the node=" + node);
}

/**
 * Returns the Volcano subset containing {@code node}, or null when the planner is not a
 * {@link VolcanoPlanner} or {@code getSubset} returns null for the node.
 */
private RelSubset lookupVolcanoSubset(RelNode node) {
    return (optPlanner instanceof VolcanoPlanner)
            ? ((VolcanoPlanner) optPlanner).getSubset(node)
            : null;
}

/** Estimates a GraphPattern's row count: direct estimate, then planner partitions, then fallback. */
private Double estimatePatternRowCount(GraphPattern node, RelMetadataQuery mq) {
    Pattern pattern = node.getPattern();
    Double countEstimate = countEstimator.estimate(pattern);
    if (countEstimate != null) {
        return countEstimate;
    }
    // try to estimate count based on existed partitions by rules
    RelSubset subset = lookupVolcanoSubset(node);
    if (subset != null) {
        GraphExtendIntersect extendIntersect =
                (GraphExtendIntersect) feasibleIntersects(subset);
        if (extendIntersect != null) {
            return estimateByExtendIntersect(node, pattern, extendIntersect, mq);
        }
        GraphJoinDecomposition joinDecomposition =
                (GraphJoinDecomposition) feasibleJoinDecomposition(subset);
        if (joinDecomposition != null) {
            return estimateByJoinDecomposition(joinDecomposition, mq);
        }
    }
    return estimateByEdgeAndVertexCounts(pattern);
}

/**
 * Estimates {@code pattern} from an extend-intersect partition. The extension is modelled as a
 * small pattern of every edge adjacent to the extend step's target vertex. That small pattern is
 * combined with input 0 of the intersect through the overloaded {@code getRowCount}.
 */
private Double estimateByExtendIntersect(
        GraphPattern node,
        Pattern pattern,
        GraphExtendIntersect extendIntersect,
        RelMetadataQuery mq) {
    ExtendStep extendStep = extendIntersect.getGlogueEdge().getExtendStep();
    PatternVertex target = pattern.getVertexByOrder(extendStep.getTargetVertexOrder());
    Set<PatternEdge> adjacentEdges = pattern.getEdgesOf(target);
    Pattern extendPattern = new Pattern();
    List<PatternVertex> extendFromVertices = Lists.newArrayList();
    for (PatternEdge edge : adjacentEdges) {
        extendPattern.addVertex(edge.getSrcVertex());
        extendPattern.addVertex(edge.getDstVertex());
        extendPattern.addEdge(edge.getSrcVertex(), edge.getDstVertex(), edge);
        // the endpoint of this edge that is not the target, i.e. where the extension starts
        extendFromVertices.add(Utils.getExtendFromVertex(edge, target));
    }
    return getRowCount(
            (GraphPattern) subGraphPattern(extendIntersect, 0),
            new GraphPattern(node.getCluster(), node.getTraitSet(), extendPattern),
            extendFromVertices,
            mq);
}

/**
 * Estimates a pattern from a join-decomposition partition by combining its two sub-patterns.
 * The join vertices are resolved in the build pattern by their right-side order ids.
 */
private Double estimateByJoinDecomposition(
        GraphJoinDecomposition joinDecomposition, RelMetadataQuery mq) {
    Pattern buildPattern = joinDecomposition.getBuildPattern();
    List<PatternVertex> jointVertices =
            joinDecomposition.getJoinVertexPairs().stream()
                    .map(k -> buildPattern.getVertexByOrder(k.getRightOrderId()))
                    .collect(Collectors.toList());
    return getRowCount(
            (GraphPattern) subGraphPattern(joinDecomposition, 0),
            (GraphPattern) subGraphPattern(joinDecomposition, 1),
            jointVertices,
            mq);
}

/**
 * Fallback estimate: the product of all edge counts divided by {@code count(v)^(degree(v) - 1)}
 * for every vertex {@code v} with a positive degree in the pattern.
 *
 * <p>If a vertex shared by more than one edge has a non-positive count, this returns 0. The
 * pattern cannot match in that case. Dividing by {@code pow(0, degree - 1) == 0} would instead
 * produce NaN or Infinity, which corrupts cost comparisons in the planner.
 */
private double estimateByEdgeAndVertexCounts(Pattern pattern) {
    double totalRowCount = 1.0d;
    for (PatternEdge edge : pattern.getEdgeSet()) {
        totalRowCount *= countEstimator.estimate(edge);
    }
    for (PatternVertex vertex : pattern.getVertexSet()) {
        int degree = pattern.getEdgesOf(vertex).size();
        if (degree > 0) {
            double vertexCount = countEstimator.estimate(vertex);
            if (degree > 1 && vertexCount <= 0.0d) {
                return 0.0d;
            }
            totalRowCount /= Math.pow(vertexCount, degree - 1);
        }
    }
    return totalRowCount;
}