private static Hop simplifyWeightedSquaredLoss()

in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [1457:1649]


	/**
	 * Tries to replace a squared-loss expression over a matrix product by a
	 * single quaternary WSLOSS hop. Four expression shapes are recognized
	 * (each also with the operands of the minus swapped):
	 * <pre>
	 * 1) sum (W * (X - U %*% t(V)) ^ 2)   (post weighting)
	 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
	 * 3) sum ((X - (U %*% t(V))) ^ 2)     (no weighting)
	 * 4) sumSq (X - U %*% t(V))           (no weighting)
	 * </pre>
	 * The patterns are probed in this order and the first match wins. On a
	 * match, the new hop is linked into {@code parent} at {@code pos} in
	 * place of {@code hi}.
	 *
	 * @param parent parent hop whose child reference is replaced on a match
	 * @param hi     root hop of the candidate expression
	 * @param pos    position of {@code hi} among the inputs of {@code parent}
	 * @return the newly created quaternary hop if a pattern matched,
	 *         otherwise the unmodified {@code hi}
	 */
	private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos)
	{
		//NOTE: a general simplification without custom operator might exist as well,
		//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2

		//replacement hop; remains null as long as no pattern has matched
		Hop wsloss = null;

		//patterns 1-3: rooted by a full sum() over a binary operation,
		//not applied for vector-vector multiplications
		if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)
			&& hi.getInput(0) instanceof BinaryOp
			&& hi.getInput(0).getDim2() > 1 )
		{
			BinaryOp body = (BinaryOp) hi.getInput(0);

			//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
			//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
			if( body.getOp() == OpOp2.MULT
				&& HopRewriteUtils.isBinary(body.getInput(1), OpOp2.POW)
				&& body.getInput(0).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(body.getInput(0), body.getInput(1)) //prevent mv
				&& HopRewriteUtils.isLiteralOfValue(body.getInput(1).getInput(1), 2) )
			{
				Hop weights = body.getInput(0);
				Hop residual = body.getInput(1).getInput(0); //(X - U %*% t(V))

				if( HopRewriteUtils.isBinary(residual, OpOp2.MINUS)
					&& HopRewriteUtils.isEqualSize(residual.getInput(0), residual.getInput(1)) //prevent mv
					&& residual.getInput(0).getDataType() == DataType.MATRIX )
				{
					//locate the matrix product: 1 for (X - UVt), 0 for (UVt - X), -1 for no match;
					//the left product input has to satisfy the BLOCKSIZE CONSTRAINT
					boolean productOnRight = residual.getInput(1) instanceof AggBinaryOp
						&& HopRewriteUtils.isSingleBlock(residual.getInput(1).getInput(0), true);
					int uvIndex = productOnRight ? 1
						: ( residual.getInput(0) instanceof AggBinaryOp
							&& HopRewriteUtils.isSingleBlock(residual.getInput(0).getInput(0), true) ) ? 0 : -1;

					if( uvIndex >= 0 ) { //rewrite match
						Hop product = residual.getInput(uvIndex);
						Hop data = residual.getInput(1 - uvIndex);
						Hop uFactor = product.getInput(0);
						Hop vFactor = product.getInput(1);
						//strip an existing transpose, otherwise add one
						vFactor = HopRewriteUtils.isTransposeOperation(vFactor) ?
							vFactor.getInput(0) : HopRewriteUtils.createTranspose(vFactor);

						//special case post_nz: weights are the non-zero indicator of the data
						if( HopRewriteUtils.isNonZeroIndicator(weights, data) ) {
							weights = new LiteralOp(1);
						}

						//construct quaternary hop
						wsloss = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64,
							OpOp4.WSLOSS, data, uFactor, vFactor, weights, true);
						HopRewriteUtils.setOutputParametersForScalar(wsloss);

						LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}

			//Patterns 2 and 3 share the same root: sum ((A - B) ^ 2) with A, B of equal matrix size
			if( wsloss == null
				&& body.getOp() == OpOp2.POW
				&& HopRewriteUtils.isLiteralOfValue(body.getInput(1), 2)
				&& HopRewriteUtils.isBinary(body.getInput(0), OpOp2.MINUS)
				&& HopRewriteUtils.isEqualMatrixSize((BinaryOp) body.getInput(0)) ) //prevent mv
			{
				Hop residual = body.getInput(0);
				Hop minuend = residual.getInput(0);
				Hop subtrahend = residual.getInput(1);

				//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
				//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
				int wuvIndex =
					( subtrahend instanceof BinaryOp && subtrahend.getInput(1) instanceof AggBinaryOp ) ? 1
					: ( minuend instanceof BinaryOp && minuend.getInput(1) instanceof AggBinaryOp ) ? 0 : -1;

				if( wuvIndex >= 0 ) { //rewrite match
					Hop data = residual.getInput(1 - wuvIndex);
					Hop weighted = residual.getInput(wuvIndex); //(W * (U %*% t(V)))

					if( HopRewriteUtils.isBinary(weighted, OpOp2.MULT)
						&& weighted.getInput(0).getDataType() == DataType.MATRIX
						&& HopRewriteUtils.isEqualSize(weighted.getInput(0), weighted.getInput(1)) //prevent mv
						&& HopRewriteUtils.isSingleBlock(weighted.getInput(1).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
					{
						Hop weights = weighted.getInput(0);
						Hop product = weighted.getInput(1);
						Hop uFactor = product.getInput(0);
						Hop vFactor = product.getInput(1);
						//strip an existing transpose, otherwise add one
						vFactor = HopRewriteUtils.isTransposeOperation(vFactor) ?
							vFactor.getInput(0) : HopRewriteUtils.createTranspose(vFactor);

						wsloss = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64,
							OpOp4.WSLOSS, data, uFactor, vFactor, weights, false);
						HopRewriteUtils.setOutputParametersForScalar(wsloss);

						LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
					}
				}

				//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
				//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
				if( wsloss == null ) {
					//the left product input has to satisfy the BLOCKSIZE CONSTRAINT
					int uvIndex =
						( subtrahend instanceof AggBinaryOp //ba guarantees matrices
							&& HopRewriteUtils.isSingleBlock(subtrahend.getInput(0), true) ) ? 1
						: ( minuend instanceof AggBinaryOp //ba guarantees matrices
							&& HopRewriteUtils.isSingleBlock(minuend.getInput(0), true) ) ? 0 : -1;

					if( uvIndex >= 0 ) { //rewrite match
						Hop data = residual.getInput(1 - uvIndex);
						Hop product = residual.getInput(uvIndex); //(U %*% t(V))
						Hop weights = new LiteralOp(1); //no weighting
						Hop uFactor = product.getInput(0);
						Hop vFactor = product.getInput(1);
						//strip an existing transpose, otherwise add one
						vFactor = HopRewriteUtils.isTransposeOperation(vFactor) ?
							vFactor.getInput(0) : HopRewriteUtils.createTranspose(vFactor);

						wsloss = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64,
							OpOp4.WSLOSS, data, uFactor, vFactor, weights, false);
						HopRewriteUtils.setOutputParametersForScalar(wsloss);

						LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
					}
				}
			}
		}

		//Pattern 4) sumSq (X - U %*% t(V)) (no weighting)
		//alternative pattern: sumSq (U %*% t(V) - X)
		if( wsloss == null
			&& HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ, Direction.RowCol)
			&& HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MINUS)
			&& HopRewriteUtils.isEqualMatrixSize((BinaryOp) hi.getInput(0)) ) //prevent mv
		{
			Hop residual = hi.getInput(0);
			Hop minuend = residual.getInput(0);
			Hop subtrahend = residual.getInput(1);

			//the left product input has to satisfy the BLOCKSIZE CONSTRAINT
			int uvIndex =
				( subtrahend instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(subtrahend.getInput(0), true) ) ? 1
				: ( minuend instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(minuend.getInput(0), true) ) ? 0 : -1;

			if( uvIndex >= 0 ) { //rewrite match
				Hop data = residual.getInput(1 - uvIndex);
				Hop product = residual.getInput(uvIndex); //(U %*% t(V))
				Hop weights = new LiteralOp(1); //no weighting
				Hop uFactor = product.getInput(0);
				Hop vFactor = product.getInput(1);
				//strip an existing transpose, otherwise add one
				vFactor = HopRewriteUtils.isTransposeOperation(vFactor) ?
					vFactor.getInput(0) : HopRewriteUtils.createTranspose(vFactor);

				wsloss = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64,
					OpOp4.WSLOSS, data, uFactor, vFactor, weights, false);
				HopRewriteUtils.setOutputParametersForScalar(wsloss);

				LOG.debug("Applied simplifyWeightedSquaredLoss4"+uvIndex+" (line "+hi.getBeginLine()+")");
			}
		}

		//nothing matched: hand back the original hop untouched
		if( wsloss == null ) {
			return hi;
		}

		//relink new hop into original position
		HopRewriteUtils.replaceChildReference(parent, hi, wsloss, pos);
		return wsloss;
	}