in src/main/java/org/apache/sysds/resource/cost/CostEstimator.java [608:932]
public double parseSPInst(SPInstruction inst) throws CostEstimationException {
/* Logic for the parallelization factors:
 * the given executor metrics describe the peak performance per node,
 * i.e., when all node resources are utilized. Spark operations, however,
 * are executed by several tasks per node, so the estimated execution/read
 * time per operation is the time achievable with the full node resources,
 * divided by the number of nodes running reading tasks, and then scaled
 * by the actual number of tasks per node to account for the fact that
 * if not all cores of a node are reading, not all of its resources
 * are utilized.
 */
VarStats output;
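// dispatch on the concrete SPInstruction type: each branch collects the input VarStats,
// estimates the time for loading/broadcasting the inputs, assigns the output RDD statistics
// and accumulates the estimated instruction cost into output.rddStats.cost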
if (inst instanceof ReblockSPInstruction || inst instanceof CSVReblockSPInstruction || inst instanceof LIBSVMReblockSPInstruction) {
UnarySPInstruction uinst = (UnarySPInstruction) inst;
VarStats input = getStats((uinst).input1.getName());
output = getStats((uinst).getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.fileInfo = input.fileInfo;
// the resulting binary rdd is being hash-partitioned after the reblock
output.rddStats.hashPartitioned = true;
output.rddStats.cost = SparkCostUtils.getReblockInstTime(inst.getOpcode(), input, output, executorMetrics);
} else if (inst instanceof CheckpointSPInstruction) {
CheckpointSPInstruction cinst = (CheckpointSPInstruction) inst;
VarStats input = getStats(cinst.input1.getName());
double loadTime = loadRDDStatsAndEstimateTime(input);
output = getStats(cinst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.fileInfo = input.fileInfo;
output.rddStats.checkpoint = true;
// assume the rdd object is only marked for checkpointing;
// costs for spilling or serialization are not added
output.rddStats.cost = loadTime;
} else if (inst instanceof RandSPInstruction) {
// Rand instruction takes no RDD input;
RandSPInstruction rinst = (RandSPInstruction) inst;
String opcode = rinst.getOpcode();
int randType = -1; // default for non-random object generation operations
if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
if (rinst.getMinValue() == 0d && rinst.getMaxValue() == 0d) { // empty matrix
randType = 0;
} else if (rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue()) { // allocate, array fill
randType = 1;
} else { // full rand
randType = 2;
}
}
output = getStats(rinst.output.getName());
SparkCostUtils.assignOutputRDDStats(inst, output);
output.rddStats.cost = getRandInstTime(opcode, randType, output, executorMetrics);
} else if (inst instanceof AggregateUnarySPInstruction || inst instanceof AggregateUnarySketchSPInstruction) {
UnarySPInstruction auinst = (UnarySPInstruction) inst;
VarStats input = getStats((auinst).input1.getName());
double loadTime = loadRDDStatsAndEstimateTime(input);
output = getStats((auinst).getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.rddStats.cost = loadTime + SparkCostUtils.getAggUnaryInstTime(auinst, input, output, executorMetrics);
} else if (inst instanceof IndexingSPInstruction) {
IndexingSPInstruction ixdinst = (IndexingSPInstruction) inst;
boolean isLeftCacheType = (inst instanceof MatrixIndexingSPInstruction &&
((MatrixIndexingSPInstruction) ixdinst).getLixType() == LeftIndex.LixCacheType.LEFT);
VarStats input1; // always assigned
VarStats input2 = null; // assigned only in the case of left indexing
double loadTime = 0;
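// for (map)leftIndex, input1 carries the operand consumed as an RDD and input2 the second
// operand: broadcast for mapLeftIndex, loaded as a regular RDD for leftIndex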
if (ixdinst.getOpcode().toLowerCase().contains("left")) {
if (isLeftCacheType) {
input1 = getStats(ixdinst.input2.getName());
input2 = getStats(ixdinst.input1.getName());
} else {
input1 = getStats(ixdinst.input1.getName());
input2 = getStats(ixdinst.input2.getName());
}
if (ixdinst.getOpcode().equals(Opcodes.LEFT_INDEX.toString())) {
loadTime += loadRDDStatsAndEstimateTime(input2);
} else { // mapLeftIndex
loadTime += loadBroadcastVarStatsAndEstimateTime(input2);
}
} else {
input1 = getStats(ixdinst.input1.getName());
}
loadTime += loadRDDStatsAndEstimateTime(input1);
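// the indexing bounds are needed for inferring the output dimensions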
VarStats rowLower = getStatsWithDefaultScalar(ixdinst.getRowLower().getName());
VarStats rowUpper = getStatsWithDefaultScalar(ixdinst.getRowUpper().getName());
VarStats colLower = getStatsWithDefaultScalar(ixdinst.getColLower().getName());
VarStats colUpper = getStatsWithDefaultScalar(ixdinst.getColUpper().getName());
output = getStats(ixdinst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, rowLower, rowUpper, colLower, colUpper);
output.rddStats.cost = loadTime +
SparkCostUtils.getIndexingInstTime(ixdinst, input1, input2, output, driverMetrics, executorMetrics);
} else if (inst instanceof UnarySPInstruction) { // general unary handling; keep after all more specific unary branches
UnarySPInstruction uinst = (UnarySPInstruction) inst;
VarStats input = getStats((uinst).input1.getName());
double loadTime = loadRDDStatsAndEstimateTime(input);
output = getStats((uinst).getOutputVariableName());
if (uinst instanceof UnaryMatrixSPInstruction || inst instanceof UnaryFrameSPInstruction) {
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.rddStats.cost = loadTime + SparkCostUtils.getUnaryInstTime(uinst.getOpcode(), input, output, executorMetrics);
} else if (uinst instanceof ReorgSPInstruction || inst instanceof MatrixReshapeSPInstruction) {
if (uinst instanceof ReorgSPInstruction && uinst.getOpcode().equals(Opcodes.SORT.toString())) {
ReorgSPInstruction reorgInst = (ReorgSPInstruction) inst;
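// for sort, the index-return operand influences the inferred output statistics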
VarStats ixRet = getStatsWithDefaultScalar(reorgInst.getIxRet().getName());
SparkCostUtils.assignOutputRDDStats(inst, output, input, ixRet);
} else {
SparkCostUtils.assignOutputRDDStats(inst, output, input);
}
output.rddStats.cost = loadTime + SparkCostUtils.getReorgInstTime(uinst, input, output, executorMetrics);
} else if (uinst instanceof TsmmSPInstruction || inst instanceof Tsmm2SPInstruction) {
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.rddStats.cost = loadTime + SparkCostUtils.getTSMMInstTime(uinst, input, output, driverMetrics, executorMetrics);
} else if (uinst instanceof CentralMomentSPInstruction) {
VarStats weights = null;
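// weights are passed as input2 only when the instruction carries a separate order operand (input3)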
if (uinst.input3 != null) {
weights = getStats(uinst.input2.getName());
loadTime += loadRDDStatsAndEstimateTime(weights);
}
SparkCostUtils.assignOutputRDDStats(inst, output, input, weights);
output.rddStats.cost = loadTime +
SparkCostUtils.getCentralMomentInstTime((CentralMomentSPInstruction) uinst, input, weights, output, executorMetrics);
} else if (inst instanceof CastSPInstruction) {
SparkCostUtils.assignOutputRDDStats(inst, output, input);
output.rddStats.cost = loadTime + SparkCostUtils.getCastInstTime((CastSPInstruction) inst, input, output, executorMetrics);
} else if (inst instanceof QuantileSortSPInstruction) {
VarStats weights = null;
if (uinst.input2 != null) {
weights = getStats(uinst.input2.getName());
loadTime += loadRDDStatsAndEstimateTime(weights);
}
SparkCostUtils.assignOutputRDDStats(inst, output, input, weights);
output.rddStats.cost = loadTime +
SparkCostUtils.getQSortInstTime((QuantileSortSPInstruction) uinst, input, weights, output, executorMetrics);
} else {
throw new RuntimeException("Unsupported Unary Spark instruction of type " + inst.getClass().getName());
}
} else if (inst instanceof BinaryFrameFrameSPInstruction || inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryMatrixMatrixSPInstruction || inst instanceof BinaryMatrixScalarSPInstruction) {
BinarySPInstruction binst = (BinarySPInstruction) inst;
VarStats input1 = getStatsWithDefaultScalar((binst).input1.getName());
VarStats input2 = getStatsWithDefaultScalar((binst).input2.getName());
// handle input rdd loading
double loadTime = loadRDDStatsAndEstimateTime(input1);
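// matrix-broadcast-vector variants broadcast the second operand instead of loading it as an RDD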
if (inst instanceof BinaryMatrixBVectorSPInstruction) {
loadTime += loadBroadcastVarStatsAndEstimateTime(input2);
} else {
loadTime += loadRDDStatsAndEstimateTime(input2);
}
output = getStats((binst).getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
output.rddStats.cost = loadTime +
SparkCostUtils.getBinaryInstTime(inst, input1, input2, output, driverMetrics, executorMetrics);
} else if (inst instanceof AppendSPInstruction) {
AppendSPInstruction ainst = (AppendSPInstruction) inst;
VarStats input1 = getStats(ainst.input1.getName());
double loadTime = loadRDDStatsAndEstimateTime(input1);
VarStats input2 = getStats(ainst.input2.getName());
if (ainst instanceof AppendMSPInstruction) {
loadTime += loadBroadcastVarStatsAndEstimateTime(input2);
} else {
loadTime += loadRDDStatsAndEstimateTime(input2);
}
output = getStats(ainst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
output.rddStats.cost = loadTime + SparkCostUtils.getAppendInstTime(ainst, input1, input2, output, driverMetrics, executorMetrics);
} else if (inst instanceof AggregateBinarySPInstruction || inst instanceof PmmSPInstruction || inst instanceof PMapmmSPInstruction || inst instanceof ZipmmSPInstruction) {
BinarySPInstruction binst = (BinarySPInstruction) inst;
VarStats input1, input2;
double loadTime = 0;
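// for map-side multiplications (mapmm/pmm) the cache type decides which operand is broadcast:
// input2 always denotes the broadcast side and input1 the side consumed as an RDD;
// all other matrix multiplications load both operands as RDDs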
if (binst instanceof MapmmSPInstruction || binst instanceof PmmSPInstruction) {
MapMult.CacheType cacheType = binst instanceof MapmmSPInstruction?
((MapmmSPInstruction) binst).getCacheType() :
((PmmSPInstruction) binst).getCacheType();
if (cacheType.isRight()) {
input1 = getStats(binst.input1.getName());
input2 = getStats(binst.input2.getName());
} else {
input1 = getStats(binst.input2.getName());
input2 = getStats(binst.input1.getName());
}
loadTime += loadRDDStatsAndEstimateTime(input1);
loadTime += loadBroadcastVarStatsAndEstimateTime(input2);
} else {
input1 = getStats(binst.input1.getName());
input2 = getStats(binst.input2.getName());
loadTime += loadRDDStatsAndEstimateTime(input1);
loadTime += loadRDDStatsAndEstimateTime(input2);
}
output = getStats(binst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
output.rddStats.cost = loadTime +
SparkCostUtils.getMatMulInstTime(binst, input1, input2, output, driverMetrics, executorMetrics);
} else if (inst instanceof MapmmChainSPInstruction) {
MapmmChainSPInstruction mmchaininst = (MapmmChainSPInstruction) inst;
VarStats input1 = getStats(mmchaininst.input1.getName());
VarStats input2 = getStats(mmchaininst.input2.getName());
VarStats input3 = null;
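// the first input is consumed as an RDD; the second (and optional third) operand is broadcast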
double loadTime = loadRDDStatsAndEstimateTime(input1) + loadBroadcastVarStatsAndEstimateTime(input2);
if (mmchaininst.input3 != null) {
input3 = getStats(mmchaininst.input3.getName());
loadTime += loadBroadcastVarStatsAndEstimateTime(input3);
}
output = getStats(mmchaininst.output.getName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, input3);
output.rddStats.cost = loadTime +
SparkCostUtils.getMatMulChainInstTime(mmchaininst, input1, input2, input3, output, driverMetrics, executorMetrics);
} else if (inst instanceof CtableSPInstruction) {
CtableSPInstruction tableInst = (CtableSPInstruction) inst;
VarStats input1 = getStatsWithDefaultScalar(tableInst.input1.getName());
VarStats input2 = getStatsWithDefaultScalar(tableInst.input2.getName());
VarStats input3 = getStatsWithDefaultScalar(tableInst.input3.getName());
double loadTime = loadRDDStatsAndEstimateTime(input1) +
loadRDDStatsAndEstimateTime(input2) + loadRDDStatsAndEstimateTime(input3);
output = getStats(tableInst.getOutputVariableName());
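// optional output dimensions of the ctable call, used for inferring the output characteristics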
VarStats outDim1 = getCTableDim(tableInst.getOutDim1());
VarStats outDim2 = getCTableDim(tableInst.getOutDim2());
// the third input is not relevant for the assignment (dimension inference)
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, outDim1, outDim2);
output.rddStats.cost = loadTime +
SparkCostUtils.getCtableInstTime(tableInst, input1, input2, input3, output, executorMetrics);
} else if (inst instanceof ParameterizedBuiltinSPInstruction) {
ParameterizedBuiltinSPInstruction paramInst = (ParameterizedBuiltinSPInstruction) inst;
VarStats input1 = getParameterizedBuiltinParamStats("target", paramInst.getParameterMap(), true); // required
double loadTime = input1 != null? loadRDDStatsAndEstimateTime(input1) : 0;
VarStats input2 = null; // optional
switch (inst.getOpcode()) {
case "rmempty":
input2 = getParameterizedBuiltinParamStats("offset", paramInst.getParameterMap(), false);
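// the offset vector is broadcast when the bRmEmptyBC flag is set, otherwise loaded as a regular RDD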
if (Boolean.parseBoolean(paramInst.getParameterMap().get("bRmEmptyBC"))) {
loadTime += input2 != null? loadBroadcastVarStatsAndEstimateTime(input2) : 0;
} else {
loadTime += input2 != null? loadRDDStatsAndEstimateTime(input2) : 0;
}
break;
case "contains":
input2 = getParameterizedBuiltinParamStats("pattern", paramInst.getParameterMap(), false);
break;
case "groupedagg":
input2 = getParameterizedBuiltinParamStats("groups", paramInst.getParameterMap(), false);
// in some cases a third parameter is also needed here
break;
}
output = getStatsWithDefaultScalar(paramInst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1);
output.rddStats.cost = loadTime + SparkCostUtils.getParameterizedBuiltinInstTime(paramInst,
input1, input2, output, driverMetrics, executorMetrics);
} else if (inst instanceof TernarySPInstruction) {
TernarySPInstruction tInst = (TernarySPInstruction) inst;
VarStats input1 = getStatsWithDefaultScalar(tInst.input1.getName());
VarStats input2 = getStatsWithDefaultScalar(tInst.input2.getName());
VarStats input3 = getStatsWithDefaultScalar(tInst.input3.getName());
double loadTime = loadRDDStatsAndEstimateTime(input1) +
loadRDDStatsAndEstimateTime(input2) + loadRDDStatsAndEstimateTime(input3);
output = getStats(tInst.getOutputVariableName());
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, input3);
output.rddStats.cost = loadTime + SparkCostUtils.getTernaryInstTime(tInst,
input1, input2, input3, output, executorMetrics);
} else if (inst instanceof QuaternarySPInstruction) {
// NOTE: not all quaternary instructions are supported yet; only
// mapwdivmm, mapsigmoid, mapwumm, mapwsloss and mapwcemm
QuaternarySPInstruction quatInst = (QuaternarySPInstruction) inst;
VarStats input1 = getStats(quatInst.input1.getName());
VarStats input2 = getStats(quatInst.input2.getName());
VarStats input3 = getStats(quatInst.input3.getName());
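// the second and third inputs are broadcast (map-side quaternary operations)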
double loadTime = loadRDDStatsAndEstimateTime(input1) +
loadBroadcastVarStatsAndEstimateTime(input2) + loadBroadcastVarStatsAndEstimateTime(input3);
output = getStatsWithDefaultScalar(quatInst.getOutputVariableName()); // matrix or aggregated scalar
SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, input3);
output.rddStats.cost = loadTime + SparkCostUtils.getQuaternaryInstTime(quatInst,
input1, input2, input3, output, driverMetrics, executorMetrics);
} else if (inst instanceof WriteSPInstruction) {
WriteSPInstruction wInst = (WriteSPInstruction) inst;
VarStats input = getStats(wInst.input1.getName());
double loadTime = loadRDDStatsAndEstimateTime(input);
// extract and assign all needed parameters for writing a file
String fileName = wInst.getInput2().isLiteral()? wInst.getInput2().getLiteral().getStringValue() : "hdfs_file";
String dataSource = IOCostUtils.getDataSource(fileName); // e.g. "hdfs_file" -> "hdfs"
String formatString = wInst.getInput3().isLiteral()? wInst.getInput3().getLiteral().getStringValue() : "text";
input.fileInfo = new Object[] {dataSource, FileFormat.safeValueOf(formatString)};
// return time estimate here since no corresponding RDD statistics exist
return loadTime + IOCostUtils.getHadoopWriteTime(input, executorMetrics); // I/O estimate
}
// else if (inst instanceof CumulativeOffsetSPInstruction) {
//
// } else if (inst instanceof CovarianceSPInstruction) {
//
// } else if (inst instanceof QuantilePickSPInstruction) {
//
// } else if (inst instanceof AggregateTernarySPInstruction) {
//
// }
else {
throw new RuntimeException("Unsupported instruction: " + inst.getOpcode());
}
// output.rddStats should always be initialized at this point
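// if the output is collected at the driver, account for its memory there and return the
// accumulated RDD cost immediately; otherwise the cost remains attached to output.rddStats
// and 0 is returned here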
if (output.rddStats.isCollected) {
if (!output.isScalar()) {
output.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(output.characteristics);
putInMemory(output);
}
double ret = output.rddStats.cost;
output.rddStats = null;
return ret;
}
return 0;
}