in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java [1423:1578]
/**
 * Fuses selected binary sub-DAGs rooted at {@code hi} into cheaper equivalent operators:
 * <ul>
 *   <li>{@code (1-X)*X} or {@code X*(1-X)} into {@code sprop(X)},</li>
 *   <li>{@code 1/(1+exp(-X))} (with {@code -X} given as {@code 0-X}) into {@code sigmoid(X)},
 *       and {@code 1/(1+exp(Y))} for any other {@code Y} into {@code sigmoid(-Y)},</li>
 *   <li>{@code (X>0)*X} or {@code X*(X>0)} into {@code max(X,0)}.</li>
 * </ul>
 * At most one rewrite is applied per call. If an intermediate of a matched pattern has
 * other consumers, the rewrite is still applied and the intermediate is kept for them.
 *
 * @param parent parent of {@code hi}; on a match, its child reference at {@code pos}
 *               is redirected to the newly created hop
 * @param hi     candidate root of the pattern (only {@link BinaryOp}s are considered)
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the newly created replacement hop if a rewrite was applied, otherwise {@code hi}
 */
private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos )
{
	if( hi instanceof BinaryOp )
	{
		BinaryOp bop = (BinaryOp)hi;
		Hop left = hi.getInput(0);
		Hop right = hi.getInput(1);
		boolean applied = false;

		//sample proportion (sprop) operator
		if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
		{
			//by definition, either left or right or none applies.
			//note: if there are multiple consumers on the intermediate,
			//we follow the heuristic that redundant computation is more beneficial,
			//i.e., we still fuse but leave the intermediate for the other consumers
			if( left instanceof BinaryOp ) //(1-X)*X
			{
				BinaryOp bleft = (BinaryOp)left;
				Hop left1 = bleft.getInput(0);
				Hop left2 = bleft.getInput(1);
				if( left1 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
					left2 == right && bleft.getOp() == OpOp2.MINUS )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					//left is the (1-X) intermediate; right=X is consumed by sprop
					HopRewriteUtils.cleanupUnreferenced(bop, left);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
				}
			}
			if( !applied && right instanceof BinaryOp ) //X*(1-X)
			{
				BinaryOp bright = (BinaryOp)right;
				Hop right1 = bright.getInput(0);
				Hop right2 = bright.getInput(1);
				if( right1 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
					right2 == left && bright.getOp() == OpOp2.MINUS )
				{
					UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					//right is the (1-X) intermediate; left=X is consumed by sprop
					HopRewriteUtils.cleanupUnreferenced(bop, right);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
				}
			}
		}

		//sigmoid operator
		if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
			&& left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
		{
			//note: if there are multiple consumers on the intermediate,
			//we follow the heuristic that redundant computation is more beneficial,
			//i.e., we still fuse but leave the intermediate for the other consumers
			BinaryOp bop2 = (BinaryOp)right;
			Hop left2 = bop2.getInput(0);
			Hop right2 = bop2.getInput(1);
			if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX
				&& left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp)
			{
				UnaryOp uop = (UnaryOp) right2;
				Hop uopin = uop.getInput(0);
				if( uop.getOp()==OpOp1.EXP )
				{
					Hop sigmoidIn = null;
					//Pattern 1: 1/(1+exp(-X)) with -X given as 0-X --> sigmoid(X)
					if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) {
						BinaryOp bop3 = (BinaryOp) uopin;
						Hop left3 = bop3.getInput(0);
						if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 )
							sigmoidIn = bop3.getInput(1);
					}
					//Pattern 2: 1/(1+exp(Y)) for any other Y --> sigmoid(-Y), e.g., where -(-X)
					//has been removed by the 'remove unnecessary minus' rewrite, or Y=A-B;
					//since 1/(1+exp(Y)) = sigmoid(-Y), (re)introduce the minus
					if( sigmoidIn == null )
						sigmoidIn = HopRewriteUtils.createBinaryMinus(uopin);

					UnaryOp unary = HopRewriteUtils.createUnary(sigmoidIn, OpOp1.SIGMOID);
					HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
					//uopin is passed as well so that a now-unreferenced 0-X (pattern 1) is
					//detached from X; in pattern 2 it still feeds the new minus
					HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, uopin);
					hi = unary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
				}
			}
		}

		//select positive (selp) as max(X,0) (note: same initial pattern as sprop)
		if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
		{
			//by definition, either left or right or none applies.
			//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial
			//to replace the X*tmp with max(X,0) due to lower memory requirements and simpler sparsity propagation
			if( left instanceof BinaryOp ) //(X>0)*X
			{
				BinaryOp bleft = (BinaryOp)left;
				Hop left1 = bleft.getInput(0);
				Hop left2 = bleft.getInput(1);
				if( left2 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
					left1 == right && bleft.getOp() == OpOp2.GREATER )
				{
					BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
					HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
					//left is the (X>0) intermediate; right=X is consumed by max
					HopRewriteUtils.cleanupUnreferenced(bop, left);
					hi = binary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
				}
			}
			if( !applied && right instanceof BinaryOp ) //X*(X>0)
			{
				BinaryOp bright = (BinaryOp)right;
				Hop right1 = bright.getInput(0);
				Hop right2 = bright.getInput(1);
				if( right2 instanceof LiteralOp &&
					HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
					right1 == left && bright.getOp() == OpOp2.GREATER )
				{
					BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
					HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
					//right is the (X>0) intermediate; left=X is consumed by max
					HopRewriteUtils.cleanupUnreferenced(bop, right);
					hi = binary;
					applied = true;
					LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
				}
			}
		}
	}
	return hi;
}