in src/main/java/org/apache/sysds/resource/cost/CostEstimator.java [405:590]
/**
 * Estimates the execution time of a single CP instruction.
 * <p>
 * The method dispatches on the concrete instruction class. For each supported
 * class it looks up the {@code VarStats} of the operands, lets
 * {@code CPCostUtils.assignOutputMemoryStats} derive the output statistics,
 * adds the time returned by {@code loadCPVarStatsAndEstimateTime} for the
 * inputs and finally adds the instruction time reported by the matching
 * {@code CPCostUtils} helper. Unless a branch returns early, the output stats
 * are registered through {@code putInMemory} before returning.
 *
 * @param inst the CP instruction to estimate
 * @return the accumulated time estimate for the instruction
 * @throws CostEstimationException declared by this method; presumably raised by
 *         the stats lookup/loading helpers (not visible in this block)
 * @throws RuntimeException if the instruction type is not supported for
 *         estimation, or if the resulting estimate is negative or positive infinity
 * @throws DMLRuntimeException if a {@code MultiReturnBuiltinCPInstruction}
 *         declares a non-matrix output
 */
public double getTimeEstimateCPInst(CPInstruction inst) throws CostEstimationException {
	double time = 0;
	VarStats output = null;
	if (inst instanceof VariableCPInstruction) {
		String opcode = inst.getOpcode();
		VariableCPInstruction vinst = (VariableCPInstruction) inst;
		// only cast and write opcodes resolve an input; for all others it stays null
		VarStats input = null;
		if (opcode.startsWith("cast")) {
			input = getStatsWithDefaultScalar(vinst.getInput1().getName());
			output = getStatsWithDefaultScalar(vinst.getOutput().getName());
			CPCostUtils.assignOutputMemoryStats(inst, output, input);
		}
		else if (opcode.equals(Opcodes.WRITE.toString())) {
			input = getStatsWithDefaultScalar(vinst.getInput1().getName());
			time += IOCostUtils.getFileSystemWriteTime(input, driverMetrics); // I/O estimate
		}
		time += input == null ? 0 : loadCPVarStatsAndEstimateTime(input);
		time += CPCostUtils.getVariableInstTime(vinst, input, output, driverMetrics);
	}
	else if (inst instanceof UnaryCPInstruction) {
		UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
		output = getStatsWithDefaultScalar(uinst.getOutput().getName());
		if (inst instanceof DataGenCPInstruction || inst instanceof StringInitCPInstruction) {
			// rows and cols operands are taken from positions 1 and 2 of the instruction parts
			String[] s = InstructionUtils.getInstructionParts(uinst.getInstructionString());
			VarStats rows = getStatsWithDefaultScalar(s[1]);
			VarStats cols = getStatsWithDefaultScalar(s[2]);
			CPCostUtils.assignOutputMemoryStats(inst, output, rows, cols);
			time += CPCostUtils.getDataGenCPInstTime(uinst, output, driverMetrics);
		} else {
			// UnaryCPInstruction input can be any type of object
			VarStats input = getStatsWithDefaultScalar(uinst.input1.getName());
			// a few of the unary instructions take second optional argument of type matrix
			VarStats weights = (uinst.input2 == null || uinst.input2.isScalar()) ? null : getStats(uinst.input2.getName());
			if (inst instanceof IndexingCPInstruction) {
				// weights = second input for leftIndex operations
				IndexingCPInstruction idxInst = (IndexingCPInstruction) inst;
				VarStats rowLower = getStatsWithDefaultScalar(idxInst.getRowLower().getName());
				VarStats rowUpper = getStatsWithDefaultScalar(idxInst.getRowUpper().getName());
				VarStats colLower = getStatsWithDefaultScalar(idxInst.getColLower().getName());
				VarStats colUpper = getStatsWithDefaultScalar(idxInst.getColUpper().getName());
				CPCostUtils.assignOutputMemoryStats(inst, output, input, weights, rowLower, rowUpper, colLower, colUpper);
			} else if (inst instanceof ReorgCPInstruction && inst.getOpcode().equals(Opcodes.SORT.toString())) {
				ReorgCPInstruction reorgInst = (ReorgCPInstruction) inst;
				VarStats ixRet = getStatsWithDefaultScalar(reorgInst.getIxRet().getName());
				CPCostUtils.assignOutputMemoryStats(inst, output, input, ixRet);
			} else {
				CPCostUtils.assignOutputMemoryStats(inst, output, input);
			}
			if (opcodeRequiresScan(inst.getOpcode())) {
				time += loadCPVarStatsAndEstimateTime(input);
			} // else -> no read required
			time += weights == null ? 0 : loadCPVarStatsAndEstimateTime(weights);
			time += CPCostUtils.getUnaryInstTime(uinst, input, weights, output, driverMetrics);
		}
	}
	else if (inst instanceof BinaryCPInstruction) {
		BinaryCPInstruction binst = (BinaryCPInstruction) inst;
		VarStats input1 = getStatsWithDefaultScalar(binst.input1.getName());
		VarStats input2 = getStatsWithDefaultScalar(binst.input2.getName());
		VarStats weights = binst.input3 == null ? null : getStatsWithDefaultScalar(binst.input3.getName());
		output = getStatsWithDefaultScalar(binst.output.getName());
		if (inst instanceof AggregateBinaryCPInstruction) {
			AggregateBinaryCPInstruction aggBinInst = (AggregateBinaryCPInstruction) inst;
			// the transpose flags are handed over wrapped in VarStats objects carrying the flag as a string
			VarStats transposeLeft = new VarStats(String.valueOf(aggBinInst.transposeLeft), null);
			VarStats transposeRight = new VarStats(String.valueOf(aggBinInst.transposeRight), null);
			CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, transposeLeft, transposeRight);
		} else {
			CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);
		}
		time += loadCPVarStatsAndEstimateTime(input1);
		time += loadCPVarStatsAndEstimateTime(input2);
		time += weights == null ? 0 : loadCPVarStatsAndEstimateTime(weights);
		time += CPCostUtils.getBinaryInstTime(binst, input1, input2, weights, output, driverMetrics);
	}
	else if (inst instanceof ParameterizedBuiltinCPInstruction) {
		if (inst instanceof ParamservBuiltinCPInstruction) {
			throw new RuntimeException("ParamservBuiltinCPInstruction is not supported for estimation");
		}
		ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
		VarStats input1 = getParameterizedBuiltinParamStats("target", pinst.getParameterMap(), true); // required
		VarStats input2 = null; // optional
		// the name of the optional second parameter depends on the opcode
		switch (inst.getOpcode()) {
			case "rmempty":
				input2 = getParameterizedBuiltinParamStats("select", pinst.getParameterMap(), false);
				break;
			case "contains":
				input2 = getParameterizedBuiltinParamStats("pattern", pinst.getParameterMap(), false);
				break;
			case "groupedagg":
				input2 = getParameterizedBuiltinParamStats("groups", pinst.getParameterMap(), false);
				break;
		}
		output = getStatsWithDefaultScalar(pinst.getOutputVariableName());
		CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);
		time += input1 != null ? loadCPVarStatsAndEstimateTime(input1) : 0;
		time += input2 != null ? loadCPVarStatsAndEstimateTime(input2) : 0;
		time += CPCostUtils.getParameterizedBuiltinInstTime(pinst, input1, output, driverMetrics);
	} else if (inst instanceof MultiReturnBuiltinCPInstruction) {
		MultiReturnBuiltinCPInstruction mrbinst = (MultiReturnBuiltinCPInstruction) inst;
		VarStats input = getStats(mrbinst.input1.getName());
		VarStats[] outputs = new VarStats[mrbinst.getOutputs().size()];
		int i = 0;
		for (CPOperand operand : mrbinst.getOutputs()) {
			if (!operand.isMatrix()) {
				throw new DMLRuntimeException("MultiReturnBuiltinCPInstruction expects only matrix output objects");
			}
			VarStats current = getStats(operand.getName());
			outputs[i] = current;
			i++;
		}
		// input and outputs switched on purpose: exclusive behaviour for this instruction
		CPCostUtils.assignOutputMemoryStats(inst, input, outputs);
		for (VarStats current : outputs) putInMemory(current);
		time += loadCPVarStatsAndEstimateTime(input);
		time += CPCostUtils.getMultiReturnBuiltinInstTime(mrbinst, input, outputs, driverMetrics);
		// the only place to return directly here (output put in memory already)
		return time;
	}
	else if (inst instanceof ComputationCPInstruction) {
		if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction || inst instanceof CompressionCPInstruction || inst instanceof DeCompressionCPInstruction) {
			throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
		}
		ComputationCPInstruction cinst = (ComputationCPInstruction) inst;
		VarStats input1 = getStatsWithDefaultScalar(cinst.input1.getName()); // 1 input: AggregateTernaryCPInstruction
		// in general only the first input operand is guaranteed initialized
		// assume they can be also scalars (often operands are some literal or scalar arguments not related to the cost estimation)
		VarStats input2 = cinst.input2 == null ? null : getStatsWithDefaultScalar(cinst.input2.getName()); // 2 inputs: PMMJCPInstruction
		VarStats input3 = cinst.input3 == null ? null : getStatsWithDefaultScalar(cinst.input3.getName()); // 3 inputs: TernaryCPInstruction, CtableCPInstruction
		VarStats input4 = cinst.input4 == null ? null : getStatsWithDefaultScalar(cinst.input4.getName()); // 4 inputs (possibly): QuaternaryCPInstruction
		output = getStatsWithDefaultScalar(cinst.getOutput().getName());
		if (inst instanceof CtableCPInstruction) {
			// for ctable the explicit output dimensions are passed instead of input3/input4
			CtableCPInstruction tableInst = (CtableCPInstruction) inst;
			VarStats outDim1 = getCTableDim(tableInst.getOutDim1());
			VarStats outDim2 = getCTableDim(tableInst.getOutDim2());
			CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, outDim1, outDim2);
		} else {
			CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, input3, input4);
		}
		time += loadCPVarStatsAndEstimateTime(input1);
		time += input2 == null ? 0 : loadCPVarStatsAndEstimateTime(input2);
		time += input3 == null ? 0 : loadCPVarStatsAndEstimateTime(input3);
		time += input4 == null ? 0 : loadCPVarStatsAndEstimateTime(input4);
		time += CPCostUtils.getComputationInstTime(cinst, input1, input2, input3, input4, output, driverMetrics);
	}
	else if (inst instanceof BuiltinNaryCPInstruction) {
		BuiltinNaryCPInstruction bninst = (BuiltinNaryCPInstruction) inst;
		output = getStatsWithDefaultScalar(bninst.getOutput().getName());
		// putInMemory(output);
		if (bninst instanceof ScalarBuiltinNaryCPInstruction) {
			// NOTE(review): this early return skips both putInMemory(output) and the
			// sanity checks at the end of the method - looks deliberate, confirm
			return CPCostUtils.getBuiltinNaryInstTime(bninst, null, output, driverMetrics);
		}
		// allocate for the worst case (all operands are matrices); only matrix operands are collected
		VarStats[] inputs = new VarStats[bninst.getInputs().length];
		int numMatrixInputs = 0;
		for (CPOperand operand : bninst.getInputs()) {
			if (operand.isMatrix()) {
				VarStats input = getStatsWithDefaultScalar(operand.getName());
				time += loadCPVarStatsAndEstimateTime(input);
				inputs[numMatrixInputs] = input;
				numMatrixInputs++;
			}
		}
		// trim the array to its actual size: exactly the number of collected matrix
		// inputs, so that no trailing null element is handed to the cost utilities
		inputs = Arrays.copyOf(inputs, numMatrixInputs);
		CPCostUtils.assignOutputMemoryStats(inst, output, inputs);
		time += CPCostUtils.getBuiltinNaryInstTime(bninst, inputs, output, driverMetrics);
	}
	else { // SqlCPInstruction
		throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
	}
	if (output != null)
		putInMemory(output);
	// detection for functionality bugs
	if (time < 0) {
		throw new RuntimeException("Unexpected negative value at estimating CP instruction execution time");
	} else if (time == Double.POSITIVE_INFINITY) {
		throw new RuntimeException("Unexpected infinity value at estimating CP instruction execution time");
	}
	return time;
}