in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [1457:1649]
/**
 * Rewrites a full aggregate over a squared difference between a matrix X and
 * a matrix product U %*% t(V) into a single WSLOSS quaternary operator with
 * scalar FP64 output. The following patterns (and their variants with swapped
 * minus operands) are recognized:
 * <pre>
 *   1) sum (W * (X - U %*% t(V)) ^ 2)    (post weighting)
 *   2) sum ((X - W * (U %*% t(V))) ^ 2)  (pre weighting)
 *   3) sum ((X - (U %*% t(V))) ^ 2)      (no weighting)
 *   4) sumSq (X - U %*% t(V))            (no weighting)
 * </pre>
 * Patterns are tried in this order and the first match wins.
 *
 * @param parent parent hop whose child reference is replaced on a match
 * @param hi     candidate root hop of the pattern
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the newly created quaternary hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos)
{
    //NOTE: there might be also a general simplification without custom operator
    //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;

    if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) //patterns 1-3 rooted by sum()
        && hi.getInput(0) instanceof BinaryOp //patterns 1-3 subrooted by binary op
        && hi.getInput(0).getDim2() > 1 ) //not applied for vector-vector mult
    {
        BinaryOp bop = (BinaryOp) hi.getInput(0);

        //Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
        hnew = wslossMatchPostWeighted(hi, bop);

        //Patterns 2 and 3 share the same subroot: (left - right) ^ 2
        if( hnew == null
            && bop.getOp()==OpOp2.POW && HopRewriteUtils.isLiteralOfValue(bop.getInput(1), 2)
            && HopRewriteUtils.isBinary(bop.getInput(0), OpOp2.MINUS)
            && HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput(0)) ) //prevent mv
        {
            Hop minus = bop.getInput(0);
            //Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
            hnew = wslossMatchPreWeighted(hi, minus);
            //Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
            if( hnew == null )
                hnew = wslossMatchUnweighted(hi, minus, 3);
        }
    }

    //Pattern 4) sumSq (X - U %*% t(V)) (no weighting)
    //alternative pattern: sumSq (U %*% t(V) - X)
    if( hnew == null
        && HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ, Direction.RowCol)
        && HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MINUS)
        && HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput(0)) ) //prevent mv
    {
        hnew = wslossMatchUnweighted(hi, hi.getInput(0), 4);
    }

    //relink new hop into original position
    if( hnew != null ) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}

/**
 * Matches pattern 1, sum (W * (X - U %*% t(V)) ^ 2), or its variant
 * sum (W * (U %*% t(V) - X) ^ 2), below the given sum root.
 *
 * @param hi  sum root, used for the name and line number of the new hop
 * @param bop binary hop directly below the sum
 * @return the new WSLOSS hop, or {@code null} if the pattern does not match
 */
private static Hop wslossMatchPostWeighted(Hop hi, BinaryOp bop)
{
    boolean weightedSquare = bop.getOp()==OpOp2.MULT
        && HopRewriteUtils.isBinary(bop.getInput(1), OpOp2.POW)
        && bop.getInput(0).getDataType()==DataType.MATRIX
        && HopRewriteUtils.isEqualSize(bop.getInput(0), bop.getInput(1)) //prevent mv
        && HopRewriteUtils.isLiteralOfValue(bop.getInput(1).getInput(1), 2);
    if( !weightedSquare )
        return null;

    Hop W = bop.getInput(0);
    Hop minus = bop.getInput(1).getInput(0); //(X - U %*% t(V))
    boolean matrixDifference = HopRewriteUtils.isBinary(minus, OpOp2.MINUS)
        && HopRewriteUtils.isEqualSize(minus.getInput(0), minus.getInput(1)) //prevent mv
        && minus.getInput(0).getDataType() == DataType.MATRIX;
    if( !matrixDifference )
        return null;

    int uvIndex = wslossFindProductIndex(minus);
    if( uvIndex < 0 )
        return null;

    //rewrite match
    Hop X = minus.getInput(1 - uvIndex);
    Hop UV = minus.getInput(uvIndex);
    Hop U = UV.getInput(0);
    Hop V = wslossTransposedRightFactor(UV);
    //handle special case of post_nz: the weights are replaced by the literal 1
    if( HopRewriteUtils.isNonZeroIndicator(W, X) )
        W = new LiteralOp(1);
    Hop hnew = wslossCreateOp(hi, X, U, V, W, true);
    LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
    return hnew;
}

/**
 * Matches pattern 2, (X - W * (U %*% t(V))), or its variant
 * (W * (U %*% t(V)) - X), on the operands of the given minus hop.
 *
 * @param hi    sum root, used for the name and line number of the new hop
 * @param minus minus hop whose square is aggregated
 * @return the new WSLOSS hop, or {@code null} if the pattern does not match
 */
private static Hop wslossMatchPreWeighted(Hop hi, Hop minus)
{
    //Each operand is validated completely before it is selected, so that a
    //right operand that only superficially resembles W * (U %*% t(V)) does
    //not hide a valid left operand (consistent with patterns 1, 3, and 4).
    int wuvIndex = -1;
    if( wslossIsPreWeightedProduct(minus.getInput(1)) ) //a) X - W * (U %*% t(V))
        wuvIndex = 1;
    else if( wslossIsPreWeightedProduct(minus.getInput(0)) ) //b) W * (U %*% t(V)) - X
        wuvIndex = 0;
    if( wuvIndex < 0 )
        return null;

    //rewrite match
    Hop X = minus.getInput(1 - wuvIndex);
    Hop WUV = minus.getInput(wuvIndex); //(W * (U %*% t(V)))
    Hop W = WUV.getInput(0);
    Hop UV = WUV.getInput(1);
    Hop U = UV.getInput(0);
    Hop V = wslossTransposedRightFactor(UV);
    Hop hnew = wslossCreateOp(hi, X, U, V, W, false);
    LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
    return hnew;
}

/**
 * Matches the unweighted patterns 3 and 4, (X - U %*% t(V)) or its variant
 * (U %*% t(V) - X), on the operands of the given minus hop.
 *
 * @param hi        aggregate root, used for the name and line number of the new hop
 * @param minus     minus hop whose squared entries are aggregated
 * @param patternId pattern number (3 or 4), only used in the debug message
 * @return the new WSLOSS hop, or {@code null} if the pattern does not match
 */
private static Hop wslossMatchUnweighted(Hop hi, Hop minus, int patternId)
{
    int uvIndex = wslossFindProductIndex(minus);
    if( uvIndex < 0 )
        return null;

    //rewrite match
    Hop X = minus.getInput(1 - uvIndex);
    Hop UV = minus.getInput(uvIndex); //(U %*% t(V))
    Hop W = new LiteralOp(1); //no weighting
    Hop U = UV.getInput(0);
    Hop V = wslossTransposedRightFactor(UV);
    Hop hnew = wslossCreateOp(hi, X, U, V, W, false);
    LOG.debug("Applied simplifyWeightedSquaredLoss"+patternId+uvIndex+" (line "+hi.getBeginLine()+")");
    return hnew;
}

/**
 * Returns the index of the minus operand that is a qualifying matrix product,
 * preferring the right operand (1) over the left operand (0), or -1 if none qualifies.
 */
private static int wslossFindProductIndex(Hop minus)
{
    if( wslossIsSingleBlockProduct(minus.getInput(1)) ) //a) X - U %*% t(V)
        return 1;
    if( wslossIsSingleBlockProduct(minus.getInput(0)) ) //b) U %*% t(V) - X
        return 0;
    return -1;
}

/** Checks for a matrix product whose left factor satisfies the blocksize constraint. */
private static boolean wslossIsSingleBlockProduct(Hop h)
{
    return h instanceof AggBinaryOp //ba guarantees matrices
        && HopRewriteUtils.isSingleBlock(h.getInput(0), true); //BLOCKSIZE CONSTRAINT
}

/** Checks for W * (U %*% t(V)) with matrix W of equal size and U satisfying the blocksize constraint. */
private static boolean wslossIsPreWeightedProduct(Hop h)
{
    return h instanceof BinaryOp
        && h.getInput(1) instanceof AggBinaryOp //ba guarantees matrices
        && ((BinaryOp)h).getOp()==OpOp2.MULT
        && h.getInput(0).getDataType() == DataType.MATRIX
        && HopRewriteUtils.isEqualSize(h.getInput(0), h.getInput(1)) //prevent mv
        && HopRewriteUtils.isSingleBlock(h.getInput(1).getInput(0), true); //BLOCKSIZE CONSTRAINT
}

/**
 * Returns the V operand for the matrix product U %*% t(V): the input of an
 * existing transpose is reused, otherwise a new transpose hop is created.
 */
private static Hop wslossTransposedRightFactor(Hop UV)
{
    Hop right = UV.getInput(1);
    return HopRewriteUtils.isTransposeOperation(right) ?
        right.getInput(0) : HopRewriteUtils.createTranspose(right);
}

/**
 * Creates the scalar WSLOSS quaternary hop named after {@code hi}; {@code post}
 * is the trailing constructor flag, true only for the post-weighting pattern.
 */
private static Hop wslossCreateOp(Hop hi, Hop X, Hop U, Hop V, Hop W, boolean post)
{
    Hop hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR,
        ValueType.FP64, OpOp4.WSLOSS, X, U, V, W, post);
    HopRewriteUtils.setOutputParametersForScalar(hnew);
    return hnew;
}