private static Hop simplifyWeightedUnaryMM()

in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [2175:2313]


	/**
	 * Fuses a weighted unary/scalar operation over a matrix multiply into a single
	 * {@code WUMM} {@link QuaternaryOp}. Three patterns are tried in order and at most
	 * one is applied:
	 * <ol>
	 *   <li>Pattern 1: {@code W op uop(U %*% t(V))}, where {@code op} is in
	 *       {@code LOOKUP_VALID_WDIVMM_BINARY} and {@code uop} is in
	 *       {@code LOOKUP_VALID_WUMM_UNARY}.</li>
	 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))},
	 *       only if the inner {@code W * (U %*% t(V))} has a single parent, the matrix
	 *       multiply is outer-product-like, and it is not a mmchain.</li>
	 *   <li>Pattern 2: {@code W op sop(U %*% t(V), 2)} or, for {@code sop == MULT} only,
	 *       {@code W op (2 * (U %*% t(V)))}, where {@code op} is in
	 *       {@code LOOKUP_VALID_WDIVMM_BINARY} and {@code sop} is in
	 *       {@code LOOKUP_VALID_WUMM_BINARY}.</li>
	 * </ol>
	 * Every pattern additionally requires W to be a matrix with more columns than its
	 * blocksize, W and the weighted operand to have equal size (no matrix-vector ops),
	 * more than one output column, and U to pass {@code HopRewriteUtils.isSingleBlock(U, true)}.
	 *
	 * @param parent parent hop whose input at {@code pos} is {@code hi}
	 * @param hi     candidate hop to rewrite
	 * @param pos    input index of {@code hi} within {@code parent}
	 * @return the new WUMM hop (already relinked into {@code parent}) if a pattern
	 *         applied, otherwise {@code hi} unchanged
	 */
	private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
		Hop hnew = null;
		
		//Pattern 1) (W*uop(U%*%t(V)))
		if( hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(hi.getInput(0), hi.getInput(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput(0).getDataType() == DataType.MATRIX 
			&& hi.getInput(0).getDim2() > hi.getInput(0).getBlocksize()
			&& hi.getInput(1) instanceof UnaryOp
			&& HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WUMM_UNARY) 
			&& hi.getInput(1).getInput(0) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop mm = hi.getInput(1).getInput(0);
			boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
			OpOp1 uop = ((UnaryOp)hi.getInput(1)).getOp();
			hnew = createWeightedUnaryMM(hi.getName(), hi.getInput(0),
				mm.getInput(0), mm.getInput(1), mult, uop, null);
			LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
		}

		//Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
		if( hnew == null
			&& HopRewriteUtils.isBinary(hi, OpOp2.MULT)
			&& (HopRewriteUtils.isLiteralOfValue(hi.getInput(0), 2)
				|| HopRewriteUtils.isLiteralOfValue(hi.getInput(1), 2)) )
		{
			//the non-literal operand is the side that is not the literal 2
			final Hop nl = HopRewriteUtils.isLiteralOfValue(hi.getInput(0), 2) ?
				hi.getInput(1) : hi.getInput(0);

			if( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
				&& nl.getParent().size()==1 // ensure no foreign parents
				&& HopRewriteUtils.isEqualSize(nl.getInput(0), nl.getInput(1)) //prevent mv
				&& nl.getDim2() > 1 //not applied for vector-vector mult
				&& nl.getInput(0).getDataType() == DataType.MATRIX
				&& nl.getInput(0).getDim2() > nl.getInput(0).getBlocksize()
				&& HopRewriteUtils.isOuterProductLikeMM(nl.getInput(1))
				&& (((AggBinaryOp) nl.getInput(1)).checkMapMultChain() == ChainType.NONE || nl.getInput(1).getInput(1).getDim2() > 1) //no mmchain
				&& HopRewriteUtils.isSingleBlock(nl.getInput(1).getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				final Hop mm = nl.getInput(1);
				//the literal 2 is folded into the operator as sop=MULT (same encoding as pattern 2)
				hnew = createWeightedUnaryMM(hi.getName(), nl.getInput(0),
					mm.getInput(0), mm.getInput(1), true, null, OpOp2.MULT);
				LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
			}
		}
		
		//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
		if( hnew == null
			&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(hi.getInput(0), hi.getInput(1)) //prevent mv
			&& hi.getDim2() > 1 //not applied for vector-vector mult
			&& hi.getInput(0).getDataType() == DataType.MATRIX
			&& hi.getInput(0).getDim2() > hi.getInput(0).getBlocksize()
			&& hi.getInput(1) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WUMM_BINARY) )
		{
			Hop left = hi.getInput(1).getInput(0);
			Hop right = hi.getInput(1).getInput(1);
			OpOp2 sop = ((BinaryOp)hi.getInput(1)).getOp();
			Hop abop = null;
			
			//pattern 2a) matrix-scalar operations
			if( right.getDataType()==DataType.SCALAR && right instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2
				&& left instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(left.getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				abop = left;
			}
			//pattern 2b) scalar-matrix operations (only commutative mult)
			else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp 
				&& HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2
				&& sop == OpOp2.MULT
				&& right instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(right.getInput(0),true) ) //BLOCKSIZE CONSTRAINT
			{
				abop = right;
			}
			
			if( abop != null ) {
				boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
				hnew = createWeightedUnaryMM(hi.getName(), hi.getInput(0),
					abop.getInput(0), abop.getInput(1), mult, null, sop);
				LOG.debug("Applied simplifyWeightedUnaryMM2 (line "+hi.getBeginLine()+")");
			}
		}
		
		//relink new hop into original position
		if( hnew != null ) {
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			hi = hnew;
		}
		
		return hi;
	}

	/**
	 * Builds a WUMM {@link QuaternaryOp} over {@code W}, {@code U} and the right-hand
	 * matrix-multiply operand {@code rightMM}. The operator receives V such that the
	 * product is {@code U %*% t(V)}: if {@code rightMM} is already a transpose, its input
	 * is used directly; otherwise a new transpose of {@code rightMM} is created.
	 *
	 * @param name    name of the new hop (taken from the hop being replaced)
	 * @param W       weight matrix; also determines the blocksize of the new hop
	 * @param U       left matrix-multiply operand
	 * @param rightMM right matrix-multiply operand, i.e., t(V)
	 * @param mult    whether the weighting binary op is {@code MULT}
	 * @param uop     unary op applied to the product (pattern 1), or {@code null}
	 * @param sop     binary op applied to the product (patterns 2 and 2.7), or {@code null}
	 * @return the new hop with blocksize set and size information refreshed
	 */
	private static Hop createWeightedUnaryMM(String name, Hop W, Hop U, Hop rightMM,
		boolean mult, OpOp1 uop, OpOp2 sop)
	{
		Hop V = HopRewriteUtils.isTransposeOperation(rightMM) ?
			rightMM.getInput(0) : HopRewriteUtils.createTranspose(rightMM);
		Hop hnew = new QuaternaryOp(name, DataType.MATRIX, ValueType.FP64,
			OpOp4.WUMM, W, U, V, mult, uop, sop);
		hnew.setBlocksize(W.getBlocksize());
		hnew.refreshSizeInformation();
		return hnew;
	}