private static Hop simplifyWeightedDivMM()

in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [1786:2097]


	/**
	 * Rewrites weighted divide/multiply matrix-multiplication patterns rooted at
	 * {@code hi} into a single fused {@code QuaternaryOp} of type {@code OpOp4.WDIVMM}.
	 * <p>
	 * Patterns 1-6 (incl. 1e/2e) require {@code hi} to be a matrix multiply whose left
	 * or right input is a binary op contained in {@code LOOKUP_VALID_WDIVMM_BINARY}
	 * (per the inline notes below, index 0 is MULT and index 1 is DIV). Pattern 7
	 * instead matches {@code hi} itself as an elementwise MULT of a matrix W and an
	 * outer-product-like matrix multiply.
	 * <p>
	 * Patterns are tried in the order below, and at most one is applied (tracked via
	 * {@code appliedPattern}). The order therefore matters. For example, pattern 3 is
	 * structurally identical to pattern 5 and must be tested first.
	 * <p>
	 * For the left-side patterns (t(U) %*% ...), the fused op is wrapped in an output
	 * transpose.
	 *
	 * @param parent parent hop of {@code hi}; if a pattern applies, its reference to
	 *               {@code hi} is replaced by the new hop via
	 *               {@code HopRewriteUtils.replaceChildReference}
	 * @param hi     candidate hop to rewrite
	 * @param pos    position of {@code hi} within {@code parent}, forwarded to
	 *               {@code replaceChildReference} (presumably the input index; confirm)
	 * @return the new WDIVMM hop (or its transpose) if a pattern was applied,
	 *         otherwise {@code hi} unchanged
	 */
	private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
		Hop hnew = null;
		//set by the first matching pattern; all later patterns are skipped
		boolean appliedPattern = false;
		
		//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
		//(i.e., hi = A %*% B where A or B is an elementwise DIV/MULT)
		//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops) 
		//note: && binds tighter than ||, so the dim2 > 1 check below only guards the
		//right-input alternative, not the left-input one
		if( HopRewriteUtils.isMatrixMultiply(hi)  
			&& (hi.getInput(0) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			|| hi.getInput(1) instanceof BinaryOp 
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) ) 
		{
			Hop left = hi.getInput(0);
			Hop right = hi.getInput(1);
			
			//Pattern 1) t(U) %*% (W/(U%*%t(V)))
			//alternative pattern: t(U) %*% (W*(U%*%t(V)))
			if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)	
				&& HopRewriteUtils.isEqualSize(right.getInput(0), right.getInput(1)) //prevent mv
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput(1))
				&& HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput(0); 
				Hop U = right.getInput(1).getInput(0);
				//right factor of the outer product, i.e., t(V) in the pattern notation
				Hop V = right.getInput(1).getInput(1);
				
				//left operand of the matrix multiply must be the transpose of U
				if( HopRewriteUtils.isTransposeOfItself(left, U) ) 
				{
					//recover V from t(V): strip an explicit t(), otherwise add a new transpose
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else 
						V = V.getInput(0);
					
					//distinguishes the MULT alternative from the DIV pattern
					boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
					//NOTE(review): the -1 literal appears to be a placeholder for the unused
					//X/epsilon input, and the int argument selects the variant (1 here for
					//left patterns, 2 for right patterns, 0 for pattern 7) -- confirm
					//against the QuaternaryOp constructor
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
				}
			}	
			
			//Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
			//(DIV only; x must be a scalar that is added to the outer product)
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
				&& HopRewriteUtils.isEqualSize(right.getInput(0), right.getInput(1)) //prevent mv
				&& HopRewriteUtils.isBinary(right.getInput(1), OpOp2.PLUS)
				&& right.getInput(1).getInput(1).getDataType() == DataType.SCALAR
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
				&& HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput(0); 
				Hop U = right.getInput(1).getInput(0).getInput(0);
				Hop V = right.getInput(1).getInput(0).getInput(1);
				//scalar addend x, passed to the fused op as its fourth input
				Hop X = right.getInput(1).getInput(1);
				
				if( HopRewriteUtils.isTransposeOfItself(left, U) ) 
				{
					//recover V from t(V): strip an explicit t(), otherwise add a new transpose
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else 
						V = V.getInput(0);
					
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, X, 3, false, false); // 3=>DIV_LEFT_EPS
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM1e (line "+hi.getBeginLine()+")");
				}
			}	
			
			//Pattern 2) (W/(U%*%t(V))) %*% V
			//alternative pattern: (W*(U%*%t(V))) %*% V
			if( !appliedPattern
				&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)	
				&& HopRewriteUtils.isEqualSize(left.getInput(0), left.getInput(1)) //prevent mv
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput(1))
				&& HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput(0); 
				Hop U = left.getInput(1).getInput(0);
				//right factor of the outer product, i.e., t(V) in the pattern notation
				Hop V = left.getInput(1).getInput(1);
				
				//right operand of the matrix multiply must be the transpose of that factor
				if( HopRewriteUtils.isTransposeOfItself(right, V) ) 
				{
					//recover V: strip an explicit t(); otherwise reuse the right operand,
					//which the check above established is its transpose (no new t() needed)
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else 
						V = V.getInput(0);
					
					//distinguishes the MULT alternative from the DIV pattern
					boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					//right-side patterns need no output transpose
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 2e) (W/(U%*%t(V) + x)) %*% V
			//(DIV only; x must be a scalar that is added to the outer product)
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
				&& HopRewriteUtils.isEqualSize(left.getInput(0), left.getInput(1)) //prevent mv
				&& HopRewriteUtils.isBinary(left.getInput(1), OpOp2.PLUS)
				&& left.getInput(1).getInput(1).getDataType() == DataType.SCALAR
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
				&& HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput(0); 
				Hop U = left.getInput(1).getInput(0).getInput(0);
				Hop V = left.getInput(1).getInput(0).getInput(1);
				//scalar addend x, passed to the fused op as its fourth input
				Hop X = left.getInput(1).getInput(1);
				
				if( HopRewriteUtils.isTransposeOfItself(right, V) ) 
				{
					//recover V: strip an explicit t(); otherwise reuse the right operand
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else 
						V = V.getInput(0);
					
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, X, 4, false, false); // 4=>DIV_RIGHT_EPS
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM2e (line "+hi.getBeginLine()+")");	
				}
			}
			
			//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
			//note: same structural match as pattern 5; this pattern must be tested first
			//and additionally requires W to be the nonzero indicator of X
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(right.getInput(1), OpOp2.MINUS)	
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
				&& right.getInput(1).getInput(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput(0); 
				Hop U = right.getInput(1).getInput(0).getInput(0);
				Hop V = right.getInput(1).getInput(0).getInput(1);
				Hop X = right.getInput(1).getInput(1);
				
				if(    HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
				    && HopRewriteUtils.isTransposeOfItself(left, U) )  //t(U)-U constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else 
						V = V.getInput(0);
					
					//X is passed in the W slot; the indicator W itself is not an input of
					//the fused op. The final 'true' flag appears only in the minus patterns
					//3-6 (presumably marks the U%*%t(V)-X variant; confirm)
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
				}
			}	
			
			//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
			//note: same structural match as pattern 6; this pattern must be tested first
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(left.getInput(1), OpOp2.MINUS)
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
				&& left.getInput(1).getInput(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput(0); 
				Hop U = left.getInput(1).getInput(0).getInput(0);
				Hop V = left.getInput(1).getInput(0).getInput(1);
				Hop X = left.getInput(1).getInput(1);
				
				if(    HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
					&& HopRewriteUtils.isTransposeOfItself(right, V) )  //V-t(V) constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else 
						V = V.getInput(0);
					
					//X is passed in the W slot (see pattern 3)
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
				}
			}
			
			//Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
			//reached when pattern 3's structure matched but W is not the nonzero indicator of X
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
				&& HopRewriteUtils.isBinary(right.getInput(1), OpOp2.MINUS)	
				&& HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
				&& right.getInput(1).getInput(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = right.getInput(0); 
				Hop U = right.getInput(1).getInput(0).getInput(0);
				Hop V = right.getInput(1).getInput(0).getInput(1);
				Hop X = right.getInput(1).getInput(1);
				
				if( HopRewriteUtils.isTransposeOfItself(left, U) )  //t(U)-U constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = HopRewriteUtils.createTranspose(V);
					else 
						V = V.getInput(0);
					
					//note: x and w exchanged compared to patterns 1-4, 7
					//(patterns 3/4 pass X in the W slot with no fourth input; patterns 1, 2, 7
					//have no X; here W keeps its slot and X is passed as the fourth input)
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, X, 1, true, true);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
					hnew = HopRewriteUtils.createTranspose(hnew);
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
				}
			}	
			
			//Pattern 6) (W*(U%*%t(V)-X)) %*% V
			//reached when pattern 4's structure matched but W is not the nonzero indicator of X
			if( !appliedPattern
				&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT	
				&& HopRewriteUtils.isBinary(left.getInput(1), OpOp2.MINUS)	
				&& HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
				&& left.getInput(1).getInput(1).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				Hop W = left.getInput(0); 
				Hop U = left.getInput(1).getInput(0).getInput(0);
				Hop V = left.getInput(1).getInput(0).getInput(1);
				Hop X = left.getInput(1).getInput(1);
				
				if( HopRewriteUtils.isTransposeOfItself(right, V) )  //V-t(V) constraint
				{
					if( !HopRewriteUtils.isTransposeOperation(V) )
						V = right;
					else 
						V = V.getInput(0);
					
					//note: x and w exchanged compared to patterns 1-4, 7
					//(W keeps its slot and X is passed as the fourth input, see pattern 5)
					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
							  OpOp4.WDIVMM, W, U, V, X, 2, true, true);
					hnew.setBlocksize(W.getBlocksize());
					hnew.refreshSizeInformation();
					
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedDivMM6 (line "+hi.getBeginLine()+")");
				}
			}
		}
		
		//Pattern 7) (W*(U%*%t(V)))
		//basic pattern rooted at the elementwise MULT itself (not at a matrix multiply)
		//NOTE(review): the dim2 > blocksize constraint limits this to W spanning more than
		//one column block; the rationale is not visible here -- confirm
		//NOTE(review): the AggBinaryOp cast relies on isOuterProductLikeMM (checked first,
		//short-circuit) only matching AggBinaryOp hops -- confirm
		if( !appliedPattern
			&& HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT	
			&& HopRewriteUtils.isEqualSize(hi.getInput(0), hi.getInput(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput(0).getDataType() == DataType.MATRIX 
			&& hi.getInput(0).getDim2() > hi.getInput(0).getBlocksize()
			&& HopRewriteUtils.isOuterProductLikeMM(hi.getInput(1))
			&& (((AggBinaryOp) hi.getInput(1)).checkMapMultChain() == ChainType.NONE || hi.getInput(1).getInput(1).getDim2() > 1) //no mmchain
			&& HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop W = hi.getInput(0); 
			Hop U = hi.getInput(1).getInput(0);
			Hop V = hi.getInput(1).getInput(1);
			
			//for this basic pattern, we're more conservative and only apply wdivmm if
			//W is sparse and U/V unknown or dense; or if U/V are dense
			if( (HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V))
				|| (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V)) ) {
				//recover V from t(V): strip an explicit t(), otherwise add a new transpose
				V = !HopRewriteUtils.isTransposeOperation(V) ?
					HopRewriteUtils.createTranspose(V) : V.getInput(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, 
					OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
				hnew.setBlocksize(W.getBlocksize());
				hnew.refreshSizeInformation();
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")");
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}