in src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java [234:376]
/**
 * Infers the missing dimensions of an instruction's output statistics from its operands,
 * and, if the number of non-zeros is still unknown afterwards, sets it to the cell count.
 *
 * <p>Dimensions are filled via {@code copyMissingDim}, which (judging by its name) presumably
 * only sets dimensions not already known — confirm against its implementation.
 *
 * <p>Some operands are interpreted through their {@code varName}: when it consists only of
 * decimal digits it is treated as a literal value (e.g. index bounds, Rand/Ctable dimensions).
 *
 * @param instType CP instruction type selecting the inference formula
 * @param opcode   instruction opcode, used to distinguish variants within a type
 * @param output   statistics of the instruction output; updated in place. For
 *                 {@code MultiReturnBuiltin} the roles are swapped: this is the single input.
 * @param inputs   statistics of the operands; for {@code MultiReturnBuiltin} these are the
 *                 (first, second) outputs that get their dimensions filled
 * @throws RuntimeException if no formula exists for {@code instType}, or if the (primary)
 *                          output dimensions are still unknown after inference
 */
public static void inferStats(CPType instType, String opcode, VarStats output, VarStats...inputs) {
	switch (instType) {
		case Unary:
		case Builtin:
			copyMissingDim(output, inputs[0]);
			break;
		case AggregateUnary:
			if (opcode.startsWith("uar")) { // row-wise aggregate: one value per input row
				copyMissingDim(output, inputs[0].getM(), 1);
			} else if (opcode.startsWith("uac")) { // column-wise aggregate: one value per input column
				copyMissingDim(output, 1, inputs[0].getN());
			} else { // full aggregate: single value
				copyMissingDim(output, 1, 1);
			}
			break;
		case MatrixIndexing:
			if (opcode.equals("rightIndex")) {
				// inputs[2..5]: row lower, row upper, column lower, column upper bound
				long rowRange = inferIndexRangeLength(inputs[2], inputs[3], inputs[0].getM());
				long colRange = inferIndexRangeLength(inputs[4], inputs[5], inputs[0].getN());
				copyMissingDim(output, rowRange, colRange);
			} else { // leftIndex
				copyMissingDim(output, inputs[0]);
			}
			break;
		case Reorg:
			switch (opcode) {
				case "r'":
					copyMissingDim(output, inputs[0].getN(), inputs[0].getM());
					break;
				case "rev":
					copyMissingDim(output, inputs[0]);
					break;
				case "rdiag":
					if (inputs[0].getN() == 1) // diagV2M
						copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
					else // diagM2V
						copyMissingDim(output, inputs[0].getM(), 1);
					break;
				case "rsort": {
					boolean ixRet = Boolean.parseBoolean(inputs[1].varName);
					if (ixRet)
						copyMissingDim(output, inputs[0].getM(), 1);
					else
						copyMissingDim(output, inputs[0]);
					break;
				}
			}
			break;
		case Binary: {
			// handle case of matrix-scalar op. with the matrix being the second operand
			VarStats origin = inputs[0].isScalar() ? inputs[1] : inputs[0];
			copyMissingDim(output, origin);
			break;
		}
		case AggregateBinary: {
			// optional 3rd/4th operands carry the transpose flags of the left/right matrix
			boolean transposeLeft = false;
			boolean transposeRight = false;
			if (inputs.length == 4) {
				transposeLeft = inputs[2] != null && Boolean.parseBoolean(inputs[2].varName);
				transposeRight = inputs[3] != null && Boolean.parseBoolean(inputs[3].varName);
			}
			// (m x k) %*% (k x n) -> (m x n); a transposed operand contributes its other dimension
			long rows = transposeLeft ? inputs[0].getN() : inputs[0].getM();
			long cols = transposeRight ? inputs[1].getM() : inputs[1].getN();
			copyMissingDim(output, rows, cols);
			break;
		}
		case ParameterizedBuiltin:
			if (opcode.equals(Opcodes.RMEMPTY.toString()) || opcode.equals(Opcodes.REPLACE.toString())) {
				copyMissingDim(output, inputs[0]);
			} else if (opcode.equals(Opcodes.UPPERTRI.toString()) || opcode.equals(Opcodes.LOWERTRI.toString())) {
				copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
			}
			break;
		case Rand:
			// inferring missing output dimensions is handled exceptionally here
			if (output.getCells() < 0) {
				long nrows = parseNonNegativeLiteral(inputs[0]);
				long ncols = parseNonNegativeLiteral(inputs[1]);
				copyMissingDim(output, nrows, ncols);
			}
			break;
		case Ctable: {
			long m = parseNonNegativeLiteral(inputs[2]);
			long n = parseNonNegativeLiteral(inputs[3]);
			if (m < 0)
				m = inputs[0].getM();
			if (n < 0) {
				n = inputs[1].isScalar()
					? 1 // Histogram
					: inputs[1].getCells(); // transform (incl. "ctableexpand"); NOTE: very generous assumption, it could be revised
			}
			copyMissingDim(output, m, n);
			break;
		}
		case MultiReturnBuiltin: {
			// special case: output and inputs stats arguments are swapped: always single input with multiple outputs
			VarStats firstStats = inputs[0];
			VarStats secondStats = inputs[1];
			switch (opcode) {
				case "qr":
					copyMissingDim(firstStats, output.getM(), output.getM()); // Q
					copyMissingDim(secondStats, output.getM(), output.getN()); // R
					break;
				case "lu":
					copyMissingDim(firstStats, output.getN(), output.getN()); // L
					copyMissingDim(secondStats, output.getN(), output.getN()); // U
					break;
				case "eigen":
					copyMissingDim(firstStats, output.getN(), 1); // values
					copyMissingDim(secondStats, output.getN(), output.getN()); // vectors
					break;
				// not all opcodes supported yet
			}
			break;
		}
		default:
			throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions");
	}
	if (output.getCells() < 0) {
		throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions");
	}
	if (output.getNNZ() < 0) {
		// unknown sparsity: assume every cell is non-zero
		output.characteristics.setNonZeros(output.getCells());
	}
}

/**
 * Interprets an operand's {@code varName} as a non-negative integer literal.
 *
 * @param stats operand statistics whose {@code varName} may hold a literal value
 * @return the parsed value, or -1 if the name is null, not purely decimal digits,
 *         or too large to fit into a {@code long}
 */
private static long parseNonNegativeLiteral(VarStats stats) {
	String name = stats.varName;
	if (name == null || !name.matches("\\d+"))
		return -1;
	try {
		return Long.parseLong(name);
	} catch (NumberFormatException overflow) {
		// digit string beyond Long.MAX_VALUE: treat as unknown, same as a non-literal
		return -1;
	}
}

/**
 * Infers the length of one indexing range (rows or columns) of a right-indexing operation.
 *
 * @param lower   statistics of the lower bound operand
 * @param upper   statistics of the upper bound operand
 * @param fullDim the corresponding dimension of the indexed input (non-positive if unknown)
 * @return {@code upper - lower + 1} if both bounds are positive literals; 1 if both bounds
 *         refer to the same variable name; otherwise {@code fullDim} if known, else
 *         {@code DEFAULT_INFERRED_DIM}
 */
private static long inferIndexRangeLength(VarStats lower, VarStats upper, long fullDim) {
	long lo = parseNonNegativeLiteral(lower);
	long hi = parseNonNegativeLiteral(upper);
	if (lo > 0 && hi > 0)
		return hi - lo + 1;
	if (lower.varName != null && lower.varName.equals(upper.varName))
		return 1;
	return fullDim > 0 ? fullDim : DEFAULT_INFERRED_DIM;
}