private static Hop fuseBinarySubDAGToUnaryOperation()

in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java [1423:1578]


	/**
	 * Fuses a small binary sub-DAG rooted at {@code hi} into a single cheaper operator.
	 * The following patterns are recognized (X is a matrix hop):
	 * <ul>
	 *   <li>{@code (1-X)*X} and {@code X*(1-X)} are replaced with {@code sprop(X)}</li>
	 *   <li>{@code 1/(1+exp(0-X))} is replaced with {@code sigmoid(X)}</li>
	 *   <li>{@code 1/(1+exp(Y))}, where Y is not a binary minus, is replaced with
	 *       {@code sigmoid(-Y)} (re-introducing a minus that an earlier rewrite removed)</li>
	 *   <li>{@code (X>0)*X} and {@code X*(X>0)} are replaced with {@code max(X,0)}</li>
	 * </ul>
	 * At most one rewrite is applied per call. If a matched intermediate has other
	 * consumers, the rewrite is still applied and the intermediate is kept for them.
	 *
	 * @param parent consumer of {@code hi}; its input at {@code pos} is redirected to the new hop
	 * @param hi     candidate root of the pattern (only {@link BinaryOp} roots are considered)
	 * @param pos    input position of {@code hi} within {@code parent}
	 * @return the replacement hop if a rewrite was applied, otherwise {@code hi} unchanged
	 */
	private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos )
	{
		if( hi instanceof BinaryOp )
		{
			BinaryOp bop = (BinaryOp)hi;
			Hop left = hi.getInput(0);
			Hop right = hi.getInput(1);
			boolean applied = false;

			//sample proportion (sprop) operator
			if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
			{
				//by definition, either left or right or none applies.
				//note: if there are multiple consumers on the intermediate,
				//we follow the heuristic that redundant computation is more beneficial,
				//i.e., we still fuse but leave the intermediate for the other consumers

				if( left instanceof BinaryOp ) //(1-X)*X
				{
					BinaryOp bleft = (BinaryOp)left;
					Hop left1 = bleft.getInput(0);
					Hop left2 = bleft.getInput(1);

					if( left1 instanceof LiteralOp &&
							HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
							left2 == right && bleft.getOp() == OpOp2.MINUS  )
					{
						UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
						HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
						//left is the now-unused (1-X) intermediate
						HopRewriteUtils.cleanupUnreferenced(bop, left);
						hi = unary;
						applied = true;

						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
					}
				}
				if( !applied && right instanceof BinaryOp ) //X*(1-X)
				{
					BinaryOp bright = (BinaryOp)right;
					Hop right1 = bright.getInput(0);
					Hop right2 = bright.getInput(1);

					if( right1 instanceof LiteralOp &&
							HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
							right2 == left && bright.getOp() == OpOp2.MINUS )
					{
						UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
						HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
						//the (1-X) intermediate is the right input here; X (left) is reused by sprop
						HopRewriteUtils.cleanupUnreferenced(bop, right);
						hi = unary;
						applied = true;

						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
					}
				}
			}

			//sigmoid operator
			if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
					&& left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
			{
				//note: if there are multiple consumers on the intermediate,
				//we follow the heuristic that redundant computation is more beneficial,
				//i.e., we still fuse but leave the intermediate for the other consumers

				BinaryOp bop2 = (BinaryOp)right;
				Hop left2 = bop2.getInput(0);
				Hop right2 = bop2.getInput(1);

				if(    bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX
						&& left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp)
				{
					UnaryOp uop = (UnaryOp) right2;
					Hop uopin = uop.getInput(0);

					if( uop.getOp()==OpOp1.EXP )
					{
						UnaryOp unary = null;

						//Pattern 1: 1/(1 + exp(-X))
						if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) {
							BinaryOp bop3 = (BinaryOp) uopin;
							Hop left3 = bop3.getInput(0);
							Hop right3 = bop3.getInput(1);

							if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 )
								unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
						}
						//Pattern 2: 1/(1 + exp(X)), e.g., where -(-X) has been removed by
						//the 'remove unnecessary minus' rewrite --> reintroduce the minus
						else {
							BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
							unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
						}

						if( unary != null ) {
							HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
							//also pass uopin: in pattern 1 the (0-X) intermediate becomes unused once
							//exp is detached; in pattern 2 uopin feeds the new minus and stays referenced
							HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, uopin);
							hi = unary;
							applied = true;

							LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
						}
					}
				}
			}

			//select positive (selp) operator (note: same initial pattern as sprop)
			if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
			{
				//by definition, either left or right or none applies.
				//note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial
				//to replace the X*tmp with max(X,0) due to lower memory requirements and simpler sparsity propagation
				if( left instanceof BinaryOp ) //(X>0)*X
				{
					BinaryOp bleft = (BinaryOp)left;
					Hop left1 = bleft.getInput(0);
					Hop left2 = bleft.getInput(1);

					if( left2 instanceof LiteralOp &&
							HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
							left1 == right && bleft.getOp() == OpOp2.GREATER )
					{
						BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
						HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
						//left is the now-unused (X>0) intermediate
						HopRewriteUtils.cleanupUnreferenced(bop, left);
						hi = binary;
						applied = true;

						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
					}
				}
				if( !applied && right instanceof BinaryOp ) //X*(X>0)
				{
					BinaryOp bright = (BinaryOp)right;
					Hop right1 = bright.getInput(0);
					Hop right2 = bright.getInput(1);

					if( right2 instanceof LiteralOp &&
							HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
							right1 == left && bright.getOp() == OpOp2.GREATER )
					{
						BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
						HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
						//the (X>0) intermediate is the right input here; X (left) is reused by max
						HopRewriteUtils.cleanupUnreferenced(bop, right);
						hi = binary;
						applied = true;

						LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
					}
				}
			}
		}

		return hi;
	}