public static void inferStats()

in src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java [234:376]


	/**
	 * Fills in missing output dimensions of a CP instruction from the statistics of its operands.
	 * If the output's number of non-zeros is still unknown afterwards, it is set to the output's
	 * cell count (i.e. the output is assumed dense).
	 * <p>
	 * For {@code MultiReturnBuiltin} instructions the argument roles are swapped: {@code output}
	 * holds the single input and {@code inputs[0]}/{@code inputs[1]} hold the outputs whose
	 * dimensions are filled in.
	 *
	 * @param instType CP instruction type selecting the inference formula
	 * @param opcode   instruction opcode, used to refine the formula within a type
	 * @param output   statistics of the output variable; missing dimensions are set in place
	 * @param inputs   statistics of the operands; literal operands (bounds, sizes, flags) are read
	 *                 from their {@code varName}
	 * @throws RuntimeException if there is no formula for the instruction, or if the output
	 *                          dimensions are still unknown after inference
	 */
	public static void inferStats(CPType instType, String opcode, VarStats output, VarStats...inputs) {
		switch (instType) {
			case Unary:
			case Builtin:
				// output takes over the input's dimensions
				copyMissingDim(output, inputs[0]);
				break;
			case AggregateUnary:
				if (opcode.startsWith("uar")) { // row-wise aggregate -> M x 1
					copyMissingDim(output, inputs[0].getM(), 1);
				} else if (opcode.startsWith("uac")) { // column-wise aggregate -> 1 x N
					copyMissingDim(output, 1, inputs[0].getN());
				} else { // full aggregate -> 1 x 1
					copyMissingDim(output, 1, 1);
				}
				break;
			case MatrixIndexing:
				if (opcode.equals("rightIndex")) {
					// inputs[2..3] are the row bounds, inputs[4..5] the column bounds
					long rowRange = inferIndexRange(inputs[2], inputs[3], inputs[0].getM());
					long colRange = inferIndexRange(inputs[4], inputs[5], inputs[0].getN());
					copyMissingDim(output, rowRange, colRange);
				} else { // leftIndex
					copyMissingDim(output, inputs[0]);
				}
				break;
			case Reorg:
				switch (opcode) {
					case "r'": // transpose swaps rows and columns
						copyMissingDim(output, inputs[0].getN(), inputs[0].getM());
						break;
					case "rev":
						copyMissingDim(output, inputs[0]);
						break;
					case "rdiag":
						if (inputs[0].getN() == 1) // diagV2M
							copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
						else // diagM2V
							copyMissingDim(output, inputs[0].getM(), 1);
						break;
					case "rsort":
						// inputs[1] carries the index-return flag as a literal
						boolean returnIndexes = Boolean.parseBoolean(inputs[1].varName);
						if (returnIndexes)
							copyMissingDim(output, inputs[0].getM(), 1);
						else
							copyMissingDim(output, inputs[0]);
						break;
					// unsupported opcodes are reported by the completeness check below
				}
				break;
			case Binary:
				// handle case of matrix-scalar op. with the matrix being the second operand
				VarStats origin = inputs[0].isScalar()? inputs[1] : inputs[0];
				copyMissingDim(output, origin);
				break;
			case AggregateBinary:
				// optional 3rd/4th operands: literal flags, presumably whether the left/right operand
				// is transposed before the multiplication
				boolean transposeLeft = false;
				boolean transposeRight = false;
				if (inputs.length == 4) {
					transposeLeft = inputs[2] != null && Boolean.parseBoolean(inputs[2].varName);
					transposeRight = inputs[3] != null && Boolean.parseBoolean(inputs[3].varName);
				}
				// op(A) %*% op(B): rows come from op(A), columns from op(B)
				// (getM() = rows, getN() = columns, consistent with the "r'" case above)
				long productRows = transposeLeft ? inputs[0].getN() : inputs[0].getM();
				long productCols = transposeRight ? inputs[1].getM() : inputs[1].getN();
				copyMissingDim(output, productRows, productCols);
				break;
			case ParameterizedBuiltin:
				if (opcode.equals(Opcodes.RMEMPTY.toString()) || opcode.equals(Opcodes.REPLACE.toString())) {
					// input dimensions (an upper bound in the case of removeEmpty)
					copyMissingDim(output, inputs[0]);
				} else if (opcode.equals(Opcodes.UPPERTRI.toString()) || opcode.equals(Opcodes.LOWERTRI.toString())) {
					copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
				}
				break;
			case Rand:
				// inferring missing output dimensions is handled exceptionally here:
				// rows/cols are given as literals in inputs[0]/inputs[1]
				if (output.getCells() < 0) {
					long nrows = parseDimLiteral(inputs[0]);
					long ncols = parseDimLiteral(inputs[1]);
					copyMissingDim(output, nrows, ncols);
				}
				break;
			case Ctable:
				// explicit output dimensions may be given as literals in inputs[2]/inputs[3]
				long ctableRows = parseDimLiteral(inputs[2]);
				long ctableCols = parseDimLiteral(inputs[3]);
				if (ctableRows < 0)
					ctableRows = inputs[0].getM();
				if (ctableCols < 0) {
					// histogram (scalar second operand) -> single column;
					// transform (including "ctableexpand") -> very generous assumption, it could be revised
					ctableCols = inputs[1].isScalar() ? 1 : inputs[1].getCells();
				}
				copyMissingDim(output, ctableRows, ctableCols);
				break;
			case MultiReturnBuiltin:
				// special case: output and inputs stats arguments are swapped: always single input with multiple outputs
				VarStats firstOut = inputs[0];
				VarStats secondOut = inputs[1];
				switch (opcode) {
					case "qr":
						copyMissingDim(firstOut, output.getM(), output.getM()); // Q
						copyMissingDim(secondOut, output.getM(), output.getN()); // R
						break;
					case "lu":
						copyMissingDim(firstOut, output.getN(), output.getN()); // L
						copyMissingDim(secondOut, output.getN(), output.getN()); // U
						break;
					case "eigen":
						copyMissingDim(firstOut, output.getN(), 1); // values
						copyMissingDim(secondOut, output.getN(), output.getN()); // vectors
						break;
					// not all opcodes supported yet
				}
				break;
			default:
				throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions");
		}
		if (output.getCells() < 0) {
			throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions");
		}
		if (output.getNNZ() < 0) {
			output.characteristics.setNonZeros(output.getCells());
		}
	}

	/**
	 * Infers the extent of one indexed dimension from its lower/upper bound operands.
	 *
	 * @param lower     statistics of the lower bound operand (literal value in {@code varName})
	 * @param upper     statistics of the upper bound operand (literal value in {@code varName})
	 * @param sourceDim size of that dimension in the indexed matrix; non-positive if unknown
	 * @return {@code upper - lower + 1} if both bounds are positive literals; 1 if both bounds are
	 *         the same operand; otherwise {@code sourceDim}, or {@code DEFAULT_INFERRED_DIM} if
	 *         {@code sourceDim} is unknown
	 */
	private static long inferIndexRange(VarStats lower, VarStats upper, long sourceDim) {
		long lowerBound = parseDimLiteral(lower);
		long upperBound = parseDimLiteral(upper);
		if (lowerBound > 0 && upperBound > 0)
			return upperBound - lowerBound + 1;
		if (lower.varName.equals(upper.varName))
			return 1;
		return sourceDim > 0 ? sourceDim : DEFAULT_INFERRED_DIM;
	}

	/**
	 * Returns the non-negative integer literal stored in {@code stats.varName}, or -1 if
	 * {@code varName} is not purely decimal digits.
	 */
	private static long parseDimLiteral(VarStats stats) {
		return stats.varName.matches("\\d+") ? Long.parseLong(stats.varName) : -1;
	}