in src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java [411:932]
/**
 * Estimates the number of floating point operations (NFLOP) for a single CP instruction.
 * <p>
 * The estimate is selected by the instruction type and the opcode and is computed from
 * the dimensions, cell counts and sparsity reported by the given output and input
 * statistics. Most branches assign a constant per-cell cost to the opcode and scale it
 * by a cell count; some branches use a dedicated formula.
 * <p>
 * The opcode is matched case-insensitively: it is lower-cased with {@code Locale.ROOT}
 * so that matching does not depend on the default locale of the JVM.
 *
 * @param instructionType CP instruction type that selects the group of cost formulas
 * @param opcode          opcode of the instruction (for some types suffixed with an
 *                        operation type, e.g. {@code cm_mean}, {@code groupedagg_sum})
 * @param output          statistics of the instruction output; may be {@code null} only for
 *                        branches that do not read it
 * @param inputs          statistics of the instruction inputs, in operand order
 * @return the estimated NFLOP; 0 for instruction types not directly related to computation
 * @throws RuntimeException    if a required argument is missing, or if the instruction
 *                             type/opcode is known but not handled by this method
 * @throws DMLRuntimeException if the opcode or instruction type is not recognized at all
 */
public static long getInstNFLOP(
        CPType instructionType,
        String opcode,
        VarStats output,
        VarStats...inputs
) {
    // enforce lowercase for convenience; Locale.ROOT keeps the mapping locale-independent
    // (e.g. 'I' must map to 'i' regardless of the default locale)
    opcode = opcode.toLowerCase(java.util.Locale.ROOT);
    long m;
    double costs = 0;
    switch (instructionType) {
        // types corresponding to UnaryCPInstruction
        case Unary:
        case Builtin: // log and log_nz only
            if (output == null || inputs.length < 1)
                throw new RuntimeException("Not all required arguments for Unary/Builtin operations are passed initialized");
            double sparsity = inputs[0].getSparsity();
            // per-cell cost; multiplied by the sparsity where only the non-zeros are weighted
            switch (opcode) {
                case "!":
                case "isna":
                case "isnan":
                case "isinf":
                case "ceil":
                case "floor":
                    costs = 1;
                    break;
                case "abs":
                case "round":
                case "sign":
                    costs = 1 * sparsity;
                    break;
                case "sprop":
                case "sqrt":
                    costs = 2 * sparsity;
                    break;
                case "exp":
                    costs = 18 * sparsity;
                    break;
                case "sigmoid":
                    costs = 21 * sparsity;
                    break;
                case "log":
                    costs = 32;
                    break;
                case "log_nz":
                case "plogp":
                    costs = 32 * sparsity;
                    break;
                case "print":
                case "assert":
                    costs = 1;
                    break;
                case "sin":
                    costs = 18 * sparsity;
                    break;
                case "cos":
                    costs = 22 * sparsity;
                    break;
                case "tan":
                    costs = 42 * sparsity;
                    break;
                case "asin":
                case "sinh":
                    costs = 93;
                    break;
                case "acos":
                case "cosh":
                    costs = 103;
                    break;
                case "atan":
                case "tanh":
                    costs = 40;
                    break;
                case "ucumk+":
                case "ucummin":
                case "ucummax":
                case "ucum*":
                    costs = 1 * sparsity;
                    break;
                case "ucumk+*":
                    costs = 2 * sparsity;
                    break;
                case "stop":
                    costs = 0;
                    break;
                case "typeof":
                    costs = 1;
                    break;
                case "inverse":
                    costs = (4.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity();
                    break;
                case "cholesky":
                    costs = (1.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity();
                    break;
                case "det":
                case "detectschema":
                case "colnames":
                    throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
                default:
                    // at the point of implementation no further supported operations
                    throw new DMLRuntimeException("Unary operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
            return (long) (costs * output.getCells());
        case AggregateUnary:
            if (output == null || inputs.length < 1)
                throw new RuntimeException("Not all required arguments for AggregateUnary operations are passed initialized");
            switch (opcode) {
                case "nrow":
                case "ncol":
                case "length":
                case "exists":
                case "lineage":
                    return DEFAULT_NFLOP_NOOP;
                case "uak+":
                case "uark+":
                case "uack+":
                    costs = 4;
                    break;
                case "uasqk+":
                case "uarsqk+":
                case "uacsqk+":
                    costs = 5;
                    break;
                case "uamean":
                case "uarmean":
                case "uacmean":
                    costs = 7;
                    break;
                case "uavar":
                case "uarvar":
                case "uacvar":
                    costs = 14;
                    break;
                case "uamax":
                case "uarmax":
                case "uarimax":
                case "uacmax":
                case "uamin":
                case "uarmin":
                case "uarimin":
                case "uacmin":
                    costs = 1;
                    break;
                case "ua+":
                case "uar+":
                case "uac+":
                case "ua*":
                case "uar*":
                case "uac*":
                    costs = 1 * output.getSparsity();
                    break;
                // count distinct operations
                case "uacd":
                case "uacdr":
                case "uacdc":
                case "unique":
                case "uniquer":
                case "uniquec":
                    costs = 1 * output.getSparsity();
                    break;
                case "uacdap":
                case "uacdapr":
                case "uacdapc":
                    costs = 0.5 * output.getSparsity(); // do not iterate through all the cells
                    break;
                // aggregation over the diagonal of a square matrix
                case "uatrace":
                case "uaktrace":
                    return inputs[0].getM();
                default:
                    // at the point of implementation no further supported operations
                    throw new DMLRuntimeException("AggregateUnary operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
            // scale by the input dimension selected by the opcode prefix
            if (opcode.startsWith("uar")) {
                costs *= inputs[0].getM();
            } else if (opcode.startsWith("uac")) {
                costs *= inputs[0].getN();
            } else {
                costs *= inputs[0].getCells();
            }
            return (long) (costs * output.getCells());
        case MMTSJ:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for MMTSJ operations are passed initialized");
            // reduce by factor of 4: matrix multiplication better than average FLOP count
            // + multiply only upper triangular
            if (opcode.equals("tsmm_left")) {
                costs = inputs[0].getN() * (inputs[0].getSparsity() / 2);
            } else { // tsmm/tsmm_right
                costs = inputs[0].getM() * (inputs[0].getSparsity() / 2);
            }
            return (long) (costs * inputs[0].getCellsWithSparsity());
        case Reorg:
        case Reshape:
            if (output == null)
                throw new RuntimeException("Not all required arguments for Reorg/Reshape operations are passed initialized");
            if (opcode.equals(Opcodes.SORT.toString()))
                return (long) (output.getCellsWithSparsity() * (Math.log(output.getM()) / Math.log(2))); // merge sort columns (n*m*log2(m))
            return output.getCellsWithSparsity();
        case MatrixIndexing:
            if (output == null)
                throw new RuntimeException("Not all required arguments for Indexing operations are passed initialized");
            return output.getCellsWithSparsity();
        case MMChain:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for MMChain operations are passed initialized");
            // reduction by factor 2 because matrix mult better than average flop count
            // (mmchain essentially two matrix-vector multiplications)
            return (2 + 2) * inputs[0].getCellsWithSparsity() / 2;
        case QSort:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for QSort operations are passed initialized");
            // mergesort since comparator used
            m = inputs[0].getM();
            if (opcode.equals(Opcodes.QSORT.toString()))
                costs = m + m;
            else // == "qsort_wts" (with weights)
                costs = m * inputs[0].getSparsity();
            return (long) (costs + m * (int) (Math.log(m) / Math.log(2)) + m);
        case CentralMoment:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for CentralMoment operations are passed initialized");
            switch (opcode) {
                case "cm_sum":
                    throw new RuntimeException("Undefined behaviour for CentralMoment operation of type sum");
                case "cm_min":
                case "cm_max":
                case "cm_count":
                    costs = 2;
                    break;
                case "cm_mean":
                    costs = 9;
                    break;
                case "cm_variance":
                case "cm_cm2":
                    costs = 17;
                    break;
                case "cm_cm3":
                    costs = 32;
                    break;
                case "cm_cm4":
                    costs = 52;
                    break;
                case "cm_invalid":
                    // type INVALID used when unknown dimensions
                    throw new RuntimeException("CentralMoment operation of type INVALID is not supported");
                default:
                    // at the point of implementation no further supported operations
                    throw new DMLRuntimeException("CentralMoment operation with type (<opcode>_<type>) '" + opcode + "' is not supported by SystemDS");
            }
            // costs holds a whole number here, so the cast before the multiplication is lossless
            return (long) costs * inputs[0].getCellsWithSparsity();
        case UaggOuterChain:
        case Dnn:
            throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet");
        // types corresponding to BinaryCPInstruction
        case Binary:
            if (opcode.equals(Opcodes.PLUS.toString()) || opcode.equals(Opcodes.MINUS.toString())) {
                if (inputs.length < 2)
                    throw new RuntimeException("Not all required arguments for Binary operations +/- are passed initialized");
                return inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity();
            } else if (opcode.equals(Opcodes.SOLVE.toString())) {
                if (inputs.length < 1)
                    throw new RuntimeException("Not all required arguments for Binary operation 'solve' are passed initialized");
                return inputs[0].getCells() * inputs[0].getN();
            }
            if (output == null)
                throw new RuntimeException("Not all required arguments for Binary operations are passed initialized");
            switch (opcode) {
                case "*":
                case "^2":
                case "*2":
                case "max":
                case "min":
                case "-nz":
                case "==":
                case "!=":
                case "<":
                case ">":
                case "<=":
                case ">=":
                case "&&":
                case "||":
                case "xor":
                case "bitwand":
                case "bitwor":
                case "bitwxor":
                case "bitwshiftl":
                case "bitwshiftr":
                    costs = 1;
                    break;
                case "%/%":
                    costs = 6;
                    break;
                case "%%":
                    costs = 8;
                    break;
                case "/":
                    costs = 22;
                    break;
                case "log":
                case "log_nz":
                    costs = 32;
                    break;
                case "^":
                    costs = 16;
                    break;
                case "1-*":
                    costs = 2;
                    break;
                case "dropinvalidtype":
                case "dropinvalidlength":
                case "freplicate":
                case "valueswap":
                case "applyschema":
                    throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
                default:
                    // at the point of implementation no further supported operations
                    throw new DMLRuntimeException("Binary operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
            return (long) (costs * output.getCells());
        case AggregateBinary:
            if (output == null || inputs.length < 2)
                throw new RuntimeException("Not all required arguments for AggregateBinary operations are passed initialized");
            // costs represents the cost for matrix transpose
            if (opcode.contains("_tl")) costs = inputs[0].getCellsWithSparsity();
            if (opcode.contains("_tr")) costs = inputs[1].getCellsWithSparsity();
            // else ba+*/pmm (or any of cpmm/rmm/mapmm from the Spark instructions)
            // reduce by factor of 2: matrix multiplication better than average FLOP count: 2*m*n*p->m*n*p
            return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells() + (long) costs;
        case Append:
            if (inputs.length < 2)
                throw new RuntimeException("Not all required arguments for Append operation is passed initialized");
            return inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity();
        case Covariance:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for Covariance operation is passed initialized");
            return (long) (23 * inputs[0].getM() * inputs[0].getSparsity());
        case QPick:
            switch (opcode) {
                case "qpick_iqm":
                    // only this opcode reads the inputs, so validate here to keep the
                    // dedicated exceptions of the remaining opcodes reachable
                    if (inputs.length < 1)
                        throw new RuntimeException("Not all required arguments for QPick operation is passed initialized");
                    m = inputs[0].getM();
                    return (long) (2 * m + //sum of weights
                            5 * 0.25d * m + //scan to lower quantile
                            8 * 0.5 * m); //scan from lower to upper quantile
                case "qpick_median":
                case "qpick_valuepick":
                case "qpick_rangepick":
                    throw new RuntimeException("QuantilePickCPInstruction of operation type different from IQM is not supported yet");
                default:
                    throw new DMLRuntimeException("QPick operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
        // types corresponding to others CPInstruction(s)
        case Ternary:
            if (output == null)
                throw new RuntimeException("Not all required arguments for Ternary operation is passed initialized");
            switch (opcode) {
                case "+*":
                case "-*":
                    return 2 * output.getCells();
                case "ifelse":
                    return output.getCells();
                case "_map":
                    throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
                default:
                    throw new DMLRuntimeException("Ternary operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
        case AggregateTernary:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for AggregateTernary operation is passed initialized");
            if (opcode.equals(Opcodes.TAKPM.toString()) || opcode.equals(Opcodes.TACKPM.toString()))
                return 6 * inputs[0].getCellsWithSparsity();
            throw new DMLRuntimeException("AggregateTernary operation with opcode '" + opcode + "' is not supported by SystemDS");
        case Quaternary:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for Quaternary operation is passed initialized");
            if (opcode.equals(Opcodes.WSLOSS.toString()) || opcode.equals(Opcodes.WDIVMM.toString()) || opcode.equals(Opcodes.WCEMM.toString())) {
                // 4 matrices used
                return 4 * inputs[0].getCells();
            } else if (opcode.equals(Opcodes.WSIGMOID.toString()) || opcode.equals(Opcodes.WUMM.toString())) {
                // 3 matrices used
                return 3 * inputs[0].getCells();
            }
            throw new DMLRuntimeException("Quaternary operation with opcode '" + opcode + "' is not supported by SystemDS");
        case BuiltinNary:
            if (output == null)
                throw new RuntimeException("Not all required arguments for BuiltinNary operation is passed initialized");
            switch (opcode) {
                case "cbind":
                case "rbind":
                    return output.getCellsWithSparsity();
                case "nmin":
                case "nmax":
                case "n+":
                    return inputs.length * output.getCellsWithSparsity();
                case "printf":
                case "list":
                    return output.getN();
                case "eval":
                    throw new RuntimeException("EvalNaryCPInstruction is not supported yet");
                default:
                    throw new DMLRuntimeException("BuiltinNary operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
        case Ctable:
            if (output == null)
                throw new RuntimeException("Not all required arguments for Ctable operation is passed initialized");
            if (opcode.startsWith(Opcodes.CTABLE.toString())) {
                // potential high inaccuracy due to unknown output column size
                // and inferring bound on number of elements what could lead to high underestimation
                return 3 * output.getCellsWithSparsity();
            }
            throw new DMLRuntimeException("Ctable operation with opcode '" + opcode + "' is not supported by SystemDS");
        case PMMJ:
            // currently this would never be reached since the pmm instruction uses AggregateBinary op. type
            if (output == null || inputs.length < 1)
                throw new RuntimeException("Not all required arguments for PMMJ operation is passed initialized");
            if (opcode.equals(Opcodes.PMM.toString())) {
                return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells();
            }
            throw new DMLRuntimeException("PMMJ operation with opcode '" + opcode + "' is not supported by SystemDS");
        case ParameterizedBuiltin:
            // the first input is read unconditionally below, so fail with an explicit
            // message instead of an ArrayIndexOutOfBoundsException when it is missing
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for ParameterizedBuiltin operation is passed initialized");
            m = inputs[0].getM();
            switch (opcode) {
                case "contains":
                case "replace":
                case "tostring":
                    return inputs[0].getCells();
                case "nvlist":
                case "cdf":
                case "invcdf":
                case "lowertri":
                case "uppertri":
                case "rexpand":
                    if (output == null)
                        throw new RuntimeException("Not all required arguments for ParameterizedBuiltin operation is passed initialized");
                    return output.getCells();
                case "rmempty_rows":
                    if (output == null)
                        throw new RuntimeException("Not all required arguments for ParameterizedBuiltin operation is passed initialized");
                    // NOTE(review): an input sparsity of 0 makes 1/sparsity infinite and the
                    // sum below can overflow; confirm callers never pass fully empty inputs
                    return (long) (inputs[0].getM() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
                            + output.getCells();
                case "rmempty_cols":
                    if (output == null)
                        throw new RuntimeException("Not all required arguments for ParameterizedBuiltin operation is passed initialized");
                    // NOTE(review): same sparsity == 0 caveat as for rmempty_rows
                    return (long) (inputs[0].getN() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
                            + output.getCells();
                // opcode: "groupedagg"
                case "groupedagg_count":
                case "groupedagg_min":
                case "groupedagg_max":
                    return 2 * m + m;
                case "groupedagg_sum":
                    return 2 * m + 4 * m;
                case "groupedagg_mean":
                    return 2 * m + 8 * m;
                case "groupedagg_cm2":
                    return 2 * m + 16 * m;
                case "groupedagg_cm3":
                    return 2 * m + 31 * m;
                case "groupedagg_cm4":
                    return 2 * m + 51 * m;
                case "groupedagg_variance":
                    return 2 * m + 16 * m;
                case "groupedagg_invalid":
                    // type INVALID used when unknown dimensions
                    throw new RuntimeException("ParameterizedBuiltin operation with opcode 'groupedagg' of type INVALID is not supported");
                case "tokenize":
                case "transformapply":
                case "transformdecode":
                case "transformcolmap":
                case "transformmeta":
                case "autodiff":
                case "paramserv":
                    throw new RuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported yet");
                default:
                    throw new DMLRuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
        case MultiReturnBuiltin:
            if (inputs.length < 1)
                throw new RuntimeException("Not all required arguments for MultiReturnBuiltin operation is passed initialized");
            switch (opcode) {
                case "qr":
                    costs = 2;
                    break;
                case "lu":
                    costs = 16;
                    break;
                case "eigen":
                case "svd":
                    costs = 32;
                    break;
                case "fft":
                case "fft_linearized":
                    throw new RuntimeException("MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported yet");
                default:
                    throw new DMLRuntimeException(" MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
            }
            // scale up the nflop value to represent that the operations are executed by a single thread only
            // adapt later for fft/fft_linearized since they utilize all threads
            int cpuFactor = InfrastructureAnalyzer.getLocalParallelism();
            return (long) (cpuFactor * costs * inputs[0].getCells() * inputs[0].getN());
        case Prefetch:
        case EvictLineageCache:
        case Broadcast:
        case Local:
        case FCall:
        case NoOp:
            // not directly related to computation
            return 0;
        case Variable:
        case Rand:
        case StringInit:
            throw new RuntimeException(instructionType + " instructions are not handled by this method");
        case MultiReturnParameterizedBuiltin: // opcodes: transformencode
        case MultiReturnComplexMatrixBuiltin: // opcodes: ifft, ifft_linearized, stft, rcm
        case Compression: // opcode: compress
        case DeCompression: // opcode: decompress
            throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet");
        case TrigRemote:
        case Partition:
        case SpoofFused:
        case Sql:
            throw new RuntimeException("CP operation type'" + instructionType + "' is not planned for support");
        default:
            // no further supported CP types
            throw new DMLRuntimeException("CP operation type'" + instructionType + "' is not supported by SystemDS");
    }
}