public double getTimeEstimateCPInst()

in src/main/java/org/apache/sysds/resource/cost/CostEstimator.java [405:590]


	/**
	 * Estimates the execution time of a single CP instruction and updates the
	 * statistics tracked for its output variable(s).
	 * <p>
	 * The estimate is the sum of the time needed to load every input that the
	 * instruction reads (via {@code loadCPVarStatsAndEstimateTime}) and the
	 * instruction-specific time returned by the matching {@code CPCostUtils}
	 * method. For write instructions the file system write time is added as well.
	 * Output statistics are derived by {@code CPCostUtils.assignOutputMemoryStats}
	 * and the output is registered through {@code putInMemory}.
	 * <p>
	 * Two branches return directly and therefore bypass the final sanity checks
	 * on the estimate: {@code MultiReturnBuiltinCPInstruction} and
	 * {@code ScalarBuiltinNaryCPInstruction}.
	 *
	 * @param inst the CP instruction to estimate
	 * @return the estimated time, in the unit used by the cost utility methods
	 * @throws CostEstimationException declared by this method; presumably raised
	 *         by the helper calls - the body itself does not throw it directly
	 * @throws RuntimeException if the instruction type is not supported for
	 *         estimation, or if the accumulated estimate is negative or infinite
	 */
	public double getTimeEstimateCPInst(CPInstruction inst) throws CostEstimationException {
		double time = 0;
		VarStats output = null;
		if (inst instanceof VariableCPInstruction) {
			String opcode = inst.getOpcode();
			VariableCPInstruction vinst = (VariableCPInstruction) inst;
			// only cast* and write read an input here; for every other opcode input stays null
			VarStats input = null;
			if (opcode.startsWith("cast")) {
				input = getStatsWithDefaultScalar(vinst.getInput1().getName());
				output = getStatsWithDefaultScalar(vinst.getOutput().getName());
				CPCostUtils.assignOutputMemoryStats(inst, output, input);
			}
			else if (opcode.equals(Opcodes.WRITE.toString())) {
				input = getStatsWithDefaultScalar(vinst.getInput1().getName());
				time += IOCostUtils.getFileSystemWriteTime(input, driverMetrics); // I/O estimate
			}
			time += input == null? 0 : loadCPVarStatsAndEstimateTime(input);
			time += CPCostUtils.getVariableInstTime(vinst, input, output, driverMetrics);
		}
		else if (inst instanceof UnaryCPInstruction) {
			UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
			output = getStatsWithDefaultScalar(uinst.getOutput().getName());
			if (inst instanceof DataGenCPInstruction || inst instanceof StringInitCPInstruction) {
				// the instruction parts at index 1 and 2 are looked up as the rows and cols variables
				String[] s = InstructionUtils.getInstructionParts(uinst.getInstructionString());
				VarStats rows = getStatsWithDefaultScalar(s[1]);
				VarStats cols = getStatsWithDefaultScalar(s[2]);
				CPCostUtils.assignOutputMemoryStats(inst, output, rows, cols);
				time += CPCostUtils.getDataGenCPInstTime(uinst, output, driverMetrics);
			} else {
				// UnaryCPInstruction input can be any type of object
				VarStats input = getStatsWithDefaultScalar(uinst.input1.getName());
				// a few of the unary instructions take second optional argument of type matrix
				VarStats weights = (uinst.input2 == null || uinst.input2.isScalar()) ? null : getStats(uinst.input2.getName());

				if (inst instanceof IndexingCPInstruction) {
					// weights = second input for leftIndex operations
					IndexingCPInstruction idxInst = (IndexingCPInstruction) inst;
					VarStats rowLower = getStatsWithDefaultScalar(idxInst.getRowLower().getName());
					VarStats rowUpper = getStatsWithDefaultScalar(idxInst.getRowUpper().getName());
					VarStats colLower = getStatsWithDefaultScalar(idxInst.getColLower().getName());
					VarStats colUpper = getStatsWithDefaultScalar(idxInst.getColUpper().getName());
					CPCostUtils.assignOutputMemoryStats(inst, output, input, weights, rowLower, rowUpper, colLower, colUpper);
				} else if (inst instanceof ReorgCPInstruction && inst.getOpcode().equals(Opcodes.SORT.toString())) {
					ReorgCPInstruction reorgInst = (ReorgCPInstruction) inst;
					VarStats ixRet = getStatsWithDefaultScalar(reorgInst.getIxRet().getName());
					CPCostUtils.assignOutputMemoryStats(inst, output, input, ixRet);
				} else {
					CPCostUtils.assignOutputMemoryStats(inst, output, input);
				}
				if (opcodeRequiresScan(inst.getOpcode())) {
					time += loadCPVarStatsAndEstimateTime(input);
				} // else -> no read required
				time += weights == null ? 0 : loadCPVarStatsAndEstimateTime(weights);
				time += CPCostUtils.getUnaryInstTime(uinst, input, weights, output, driverMetrics);
			}
		}
		else if (inst instanceof BinaryCPInstruction) {
			BinaryCPInstruction binst = (BinaryCPInstruction) inst;
			VarStats input1 = getStatsWithDefaultScalar(binst.input1.getName());
			VarStats input2 = getStatsWithDefaultScalar(binst.input2.getName());
			VarStats weights = binst.input3 == null? null : getStatsWithDefaultScalar(binst.input3.getName());
			output = getStatsWithDefaultScalar(binst.output.getName());
			if (inst instanceof AggregateBinaryCPInstruction) {
				AggregateBinaryCPInstruction aggBinInst = (AggregateBinaryCPInstruction) inst;
				// the transpose flags are handed over wrapped in VarStats objects named after their value
				VarStats transposeLeft = new VarStats(String.valueOf(aggBinInst.transposeLeft), null);
				VarStats transposeRight = new VarStats(String.valueOf(aggBinInst.transposeRight), null);
				CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, transposeLeft, transposeRight);
			} else {
				CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);
			}

			time += loadCPVarStatsAndEstimateTime(input1);
			time += loadCPVarStatsAndEstimateTime(input2);
			time += weights == null? 0 : loadCPVarStatsAndEstimateTime(weights);
			time += CPCostUtils.getBinaryInstTime(binst, input1, input2, weights, output, driverMetrics);
		}
		else if (inst instanceof ParameterizedBuiltinCPInstruction) {
			if (inst instanceof ParamservBuiltinCPInstruction) {
				throw new RuntimeException("ParamservBuiltinCPInstruction is not supported for estimation");
			}
			ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;

			VarStats input1 = getParameterizedBuiltinParamStats("target", pinst.getParameterMap(), true); // required
			VarStats input2 = null; // optional
			// the name of the optional second parameter depends on the opcode
			switch (inst.getOpcode()) {
				case "rmempty":
					input2 = getParameterizedBuiltinParamStats("select", pinst.getParameterMap(), false);
					break;
				case "contains":
					input2 = getParameterizedBuiltinParamStats("pattern", pinst.getParameterMap(), false);
					break;
				case "groupedagg":
					input2 = getParameterizedBuiltinParamStats("groups", pinst.getParameterMap(), false);
					break;
			}
			output = getStatsWithDefaultScalar(pinst.getOutputVariableName());
			CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);

			time += input1 != null? loadCPVarStatsAndEstimateTime(input1) : 0;
			time += input2 != null? loadCPVarStatsAndEstimateTime(input2) : 0;
			time += CPCostUtils.getParameterizedBuiltinInstTime(pinst, input1, output, driverMetrics);
		} else if (inst instanceof MultiReturnBuiltinCPInstruction) {
			MultiReturnBuiltinCPInstruction mrbinst = (MultiReturnBuiltinCPInstruction) inst;
			VarStats input = getStats(mrbinst.input1.getName());
			VarStats[] outputs = new VarStats[mrbinst.getOutputs().size()];
			int i = 0;
			for (CPOperand operand : mrbinst.getOutputs()) {
				if (!operand.isMatrix()) {
					throw new DMLRuntimeException("MultiReturnBuiltinCPInstruction expects only matrix output objects");
				}
				VarStats current = getStats(operand.getName());
				outputs[i] = current;
				i++;
			}
			// input and outputs switched on purpose: exclusive behaviour for this instruction
			CPCostUtils.assignOutputMemoryStats(inst, input, outputs);
			for (VarStats current : outputs) putInMemory(current);

			time += loadCPVarStatsAndEstimateTime(input);
			time += CPCostUtils.getMultiReturnBuiltinInstTime(mrbinst, input, outputs, driverMetrics);
			// returns directly: all outputs were put in memory above, the single-output
			// handling and the sanity checks at the end of the method are skipped
			return time;
		}
		else if (inst instanceof ComputationCPInstruction) {
			if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction || inst instanceof CompressionCPInstruction || inst instanceof DeCompressionCPInstruction) {
				throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
			}
			ComputationCPInstruction cinst = (ComputationCPInstruction) inst;
			VarStats input1 = getStatsWithDefaultScalar(cinst.input1.getName()); // 1 input: AggregateTernaryCPInstruction
			// in general only the first input operand is guaranteed initialized
			// assume they can be also scalars (often operands are some literal or scalar arguments not related to the cost estimation)
			VarStats input2 = cinst.input2 == null? null : getStatsWithDefaultScalar(cinst.input2.getName()); // 2 inputs: PMMJCPInstruction
			VarStats input3 = cinst.input3 == null? null : getStatsWithDefaultScalar(cinst.input3.getName()); // 3 inputs: TernaryCPInstruction, CtableCPInstruction
			VarStats input4 = cinst.input4 == null? null : getStatsWithDefaultScalar(cinst.input4.getName()); // 4 inputs (possibly): QuaternaryCPInstruction
			output = getStatsWithDefaultScalar(cinst.getOutput().getName());
			if (inst instanceof CtableCPInstruction) {
				// for ctable the output dimension operands are passed instead of input3/input4
				CtableCPInstruction tableInst = (CtableCPInstruction) inst;
				VarStats outDim1 = getCTableDim(tableInst.getOutDim1());
				VarStats outDim2 = getCTableDim(tableInst.getOutDim2());
				CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, outDim1, outDim2);
			} else {
				CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, input3, input4);
			}

			time += loadCPVarStatsAndEstimateTime(input1);
			time += input2 == null? 0 : loadCPVarStatsAndEstimateTime(input2);
			time += input3 == null? 0 : loadCPVarStatsAndEstimateTime(input3);
			time += input4 == null? 0 : loadCPVarStatsAndEstimateTime(input4);
			time += CPCostUtils.getComputationInstTime(cinst, input1, input2, input3, input4, output, driverMetrics);
		}
		else if (inst instanceof BuiltinNaryCPInstruction) {
			BuiltinNaryCPInstruction bninst = (BuiltinNaryCPInstruction) inst;
			output = getStatsWithDefaultScalar(bninst.getOutput().getName());
			if (bninst instanceof ScalarBuiltinNaryCPInstruction) {
				// returns directly: no inputs are loaded, the output is not put in memory
				// and the sanity checks at the end of the method are skipped
				return CPCostUtils.getBuiltinNaryInstTime(bninst, null, output, driverMetrics);
			}
			// allocate for the worst case (all operands are matrices); only matrix operands are collected
			VarStats[] inputs = new VarStats[bninst.getInputs().length];
			int i = 0;
			for (CPOperand operand : bninst.getInputs()) {
				if (operand.isMatrix()) {
					VarStats input = getStatsWithDefaultScalar(operand.getName());
					time += loadCPVarStatsAndEstimateTime(input);
					inputs[i] = input;
					i++;
				}
			}
			// trim the array to the number of collected matrix inputs;
			// i is already the element count, so no +1 (that would append a null entry)
			inputs = Arrays.copyOf(inputs, i);
			CPCostUtils.assignOutputMemoryStats(inst, output, inputs);
			time += CPCostUtils.getBuiltinNaryInstTime(bninst, inputs, output, driverMetrics);
		}
		else { // SqlCPInstruction
			throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
		}

		if (output != null)
			putInMemory(output);
		// detection for functionality bugs
		if (time < 0) {
			throw new RuntimeException("Unexpected negative value at estimating CP instruction execution time");
		} else if (time == Double.POSITIVE_INFINITY) {
			throw new RuntimeException("Unexpected infinity value at estimating CP instruction execution time");
		}
		return time;
	}