in src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.java [622:766]
/**
 * Recursively computes relative compute cost weights for all hops reachable from
 * the given hop and memoizes them in {@code computeCosts} keyed by hop ID.
 * Inputs are costed before their consumer, and hops already present in the map
 * are skipped, so each hop of the DAG is costed exactly once.
 * 
 * Hop classes without a dedicated rule default to a cost of 1. Unhandled opcodes of
 * unary, binary, ternary, and aggregate-unary hops log a warning and also keep a cost of 1.
 * 
 * @param current      root hop of the (sub-)DAG to cost
 * @param partition    hop IDs of the current partition; currently only forwarded to
 *                     recursive calls and not used to restrict the traversal
 * @param computeCosts output/memo map from hop ID to its cost weight
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts)
{
	//memoization: hops reachable via multiple consumers are costed only once
	if( computeCosts.containsKey(current.getHopID()) )
		return;
	
	//recursively process children
	for( Hop c : current.getInput() )
		rGetComputeCosts(c, partition, computeCosts);
	
	//get costs for given hop (default 1 for hop types without a dedicated rule)
	double costs = 1;
	if( current instanceof UnaryOp ) {
		switch( ((UnaryOp)current).getOp() ) {
			case ABS:
			case ROUND:
			case CEIL:
			case FLOOR:
			case SIGN:    costs = 1; break;
			case SPROP:
			case SQRT:    costs = 2; break;
			case EXP:     costs = 18; break;
			case SIGMOID: costs = 21; break;
			case LOG:
			case LOG_NZ:  costs = 32; break;
			case NCOL:
			case NROW:
			case PRINT:
			case ASSERT:
			case CAST_AS_BOOLEAN:
			case CAST_AS_DOUBLE:
			case CAST_AS_INT:
			case CAST_AS_MATRIX:
			case CAST_AS_SCALAR: costs = 1; break;
			case SIN:     costs = 18; break;
			case COS:     costs = 22; break;
			case TAN:     costs = 42; break;
			case ASIN:    costs = 93; break;
			case ACOS:    costs = 103; break;
			case ATAN:    costs = 40; break;
			//TODO: the hyperbolic functions currently reuse the ASIN/ACOS/ATAN
			//weights as placeholders; they have not been calibrated separately
			case SINH:    costs = 93; break;
			case COSH:    costs = 103; break;
			case TANH:    costs = 40; break;
			case CUMSUM:
			case CUMMIN:
			case CUMMAX:
			case CUMPROD: costs = 1; break;
			case CUMSUMPROD: costs = 2; break;
			default:
				LOG.warn("Cost model not "
					+ "implemented yet for: "+((UnaryOp)current).getOp());
		}
	}
	else if( current instanceof BinaryOp ) {
		switch( ((BinaryOp)current).getOp() ) {
			case MULT:
			case PLUS:
			case MINUS:
			case MIN:
			case MAX:
			case AND:
			case OR:
			case EQUAL:
			case NOTEQUAL:
			case LESS:
			case LESSEQUAL:
			case GREATER:
			case GREATEREQUAL:
			case CBIND:
			case RBIND:   costs = 1; break;
			case INTDIV:  costs = 6; break;
			case MODULUS: costs = 8; break;
			case DIV:     costs = 22; break;
			case LOG:
			case LOG_NZ:  costs = 32; break;
			//squaring (literal exponent 2) is much cheaper than a general pow
			case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
				current.getInput().get(1), 2) ? 1 : 16); break;
			case MINUS_NZ:
			case MINUS1_MULT: costs = 2; break;
			case MOMENT:
				//unweighted central moment: the order is the second input (index 1)
				switch( getCentralMomentType(current, 1) ) {
					case 0: costs = 1; break;  //count
					case 1: costs = 8; break;  //mean
					case 2: costs = 16; break; //cm2
					case 3: costs = 31; break; //cm3
					case 4: costs = 51; break; //cm4
					case 5: costs = 16; break; //variance
				}
				break;
			case COV: costs = 23; break;
			default:
				LOG.warn("Cost model not "
					+ "implemented yet for: "+((BinaryOp)current).getOp());
		}
	}
	else if( current instanceof TernaryOp ) {
		switch( ((TernaryOp)current).getOp() ) {
			case PLUS_MULT:
			case MINUS_MULT: costs = 2; break;
			case CTABLE:     costs = 3; break;
			case MOMENT:
				//weighted central moment centralMoment(X, W, order): input 1 holds
				//the weights, so the order is the third input (index 2)
				switch( getCentralMomentType(current, 2) ) {
					case 0: costs = 2; break;  //count
					case 1: costs = 9; break;  //mean
					case 2: costs = 17; break; //cm2
					case 3: costs = 32; break; //cm3
					case 4: costs = 52; break; //cm4
					case 5: costs = 17; break; //variance
				}
				break;
			case COV: costs = 23; break;
			default:
				LOG.warn("Cost model not "
					+ "implemented yet for: "+((TernaryOp)current).getOp());
		}
	}
	else if( current instanceof ParameterizedBuiltinOp ) {
		costs = 1;
	}
	else if( current instanceof IndexingOp ) {
		costs = 1;
	}
	else if( current instanceof ReorgOp ) {
		costs = 1;
	}
	else if( current instanceof AggBinaryOp ) {
		costs = 2; //matrix vector
	}
	else if( current instanceof AggUnaryOp) {
		switch(((AggUnaryOp)current).getOp()) {
			case SUM:    costs = 4; break;
			case SUM_SQ: costs = 5; break;
			case MIN:
			case MAX:    costs = 1; break;
			default:
				LOG.warn("Cost model not "
					+ "implemented yet for: "+((AggUnaryOp)current).getOp());
		}
	}
	computeCosts.put(current.getHopID(), costs);
}

/**
 * Derives the central-moment type code of a MOMENT hop from its order input.
 * 
 * @param hop      the binary or ternary MOMENT hop
 * @param orderPos position of the order argument within the hop's inputs
 * @return the literal order value as an int, or 2 (cm2) if the order input
 *         is absent or not a literal
 */
private static int getCentralMomentType(Hop hop, int orderPos) {
	Hop order = (hop.getInput().size() > orderPos) ? hop.getInput().get(orderPos) : null;
	return (order instanceof LiteralOp) ?
		(int) HopRewriteUtils.getIntValueSafe((LiteralOp)order) : 2;
}