in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [1786:2097]
/**
 * Tries to replace an expression built around an outer-product-like matrix
 * multiplication U%*%t(V) by a single fused WDIVMM quaternary operator.
 * <p>
 * Patterns 1-6 (incl. 1e/2e) are rooted in a matrix multiply whose left or right
 * input is a binary operation accepted by LOOKUP_VALID_WDIVMM_BINARY; pattern 7 is
 * rooted in an element-wise multiply. The patterns are probed in a fixed order and
 * the first one whose constraints hold is applied; all later ones are skipped.
 * For the patterns with a leading t(U) (1, 1e, 3, 5), the new operator is wrapped
 * into an additional transpose.
 *
 * @param parent hop whose child reference is redirected to the new operator if a pattern applies
 * @param hi     candidate root of the pattern
 * @param pos    child position handed to replaceChildReference when relinking
 * @return the newly created hop if a pattern was applied, otherwise {@code hi} itself
 */
private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
	//result of the first applicable pattern; null as long as nothing matched
	Hop hnew = null;

	//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
	//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
	if( HopRewriteUtils.isMatrixMultiply(hi)
		&& ( (hi.getInput(0) instanceof BinaryOp
				&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))
			|| (hi.getInput(1) instanceof BinaryOp
				&& hi.getDim2() > 1 //not applied for vector-vector mult
				&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) ) )
	{
		final Hop lhs = hi.getInput(0);
		final Hop rhs = hi.getInput(1);

		//Pattern 1) t(U) %*% (W/(U%*%t(V)))
		//alternative pattern: t(U) %*% (W*(U%*%t(V)))
		if( rhs instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)rhs).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(rhs.getInput(0), rhs.getInput(1)) //prevent mv
			&& HopRewriteUtils.isOuterProductLikeMM(rhs.getInput(1))
			&& HopRewriteUtils.isSingleBlock(rhs.getInput(1).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = rhs.getInput(0);
			Hop uvt = rhs.getInput(1);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(lhs, u) ) {
				//hand the untransposed V to the fused operator
				v = HopRewriteUtils.isTransposeOperation(v) ?
					v.getInput(0) : HopRewriteUtils.createTranspose(v);
				boolean mult = ((BinaryOp)rhs).getOp() == OpOp2.MULT;
				Hop wdivmm = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, new LiteralOp(-1), 1, mult, false);
				wdivmm.setBlocksize(w.getBlocksize());
				wdivmm.refreshSizeInformation();
				//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(wdivmm);
				LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
		if( hnew == null
			&& HopRewriteUtils.isBinary(rhs, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
			&& HopRewriteUtils.isEqualSize(rhs.getInput(0), rhs.getInput(1)) //prevent mv
			&& HopRewriteUtils.isBinary(rhs.getInput(1), OpOp2.PLUS)
			&& rhs.getInput(1).getInput(1).getDataType() == DataType.SCALAR
			&& HopRewriteUtils.isOuterProductLikeMM(rhs.getInput(1).getInput(0))
			&& HopRewriteUtils.isSingleBlock(rhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = rhs.getInput(0);
			Hop plus = rhs.getInput(1);
			Hop uvt = plus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = plus.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(lhs, u) ) {
				v = HopRewriteUtils.isTransposeOperation(v) ?
					v.getInput(0) : HopRewriteUtils.createTranspose(v);
				Hop wdivmm = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, x, 3, false, false); // 3=>DIV_LEFT_EPS
				wdivmm.setBlocksize(w.getBlocksize());
				wdivmm.refreshSizeInformation();
				//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(wdivmm);
				LOG.debug("Applied simplifyWeightedDivMM1e (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 2) (W/(U%*%t(V))) %*% V
		//alternative pattern: (W*(U%*%t(V))) %*% V
		if( hnew == null
			&& lhs instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp)lhs).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(lhs.getInput(0), lhs.getInput(1)) //prevent mv
			&& HopRewriteUtils.isOuterProductLikeMM(lhs.getInput(1))
			&& HopRewriteUtils.isSingleBlock(lhs.getInput(1).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = lhs.getInput(0);
			Hop uvt = lhs.getInput(1);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(rhs, v) ) {
				//reuse the right-hand input of the multiply if the inner operand is not a transpose
				v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput(0) : rhs;
				boolean mult = ((BinaryOp)lhs).getOp() == OpOp2.MULT;
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, new LiteralOp(-1), 2, mult, false);
				hnew.setBlocksize(w.getBlocksize());
				hnew.refreshSizeInformation();
				LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 2e) (W/(U%*%t(V) + x)) %*% V
		if( hnew == null
			&& HopRewriteUtils.isBinary(lhs, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
			&& HopRewriteUtils.isEqualSize(lhs.getInput(0), lhs.getInput(1)) //prevent mv
			&& HopRewriteUtils.isBinary(lhs.getInput(1), OpOp2.PLUS)
			&& lhs.getInput(1).getInput(1).getDataType() == DataType.SCALAR
			&& HopRewriteUtils.isOuterProductLikeMM(lhs.getInput(1).getInput(0))
			&& HopRewriteUtils.isSingleBlock(lhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = lhs.getInput(0);
			Hop plus = lhs.getInput(1);
			Hop uvt = plus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = plus.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(rhs, v) ) {
				v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput(0) : rhs;
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, x, 4, false, false); // 4=>DIV_RIGHT_EPS
				hnew.setBlocksize(w.getBlocksize());
				hnew.refreshSizeInformation();
				LOG.debug("Applied simplifyWeightedDivMM2e (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
		if( hnew == null
			&& HopRewriteUtils.isBinary(rhs, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
			&& HopRewriteUtils.isBinary(rhs.getInput(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(rhs.getInput(1).getInput(0))
			&& rhs.getInput(1).getInput(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isSingleBlock(rhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = rhs.getInput(0);
			Hop minus = rhs.getInput(1);
			Hop uvt = minus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = minus.getInput(1);
			if( HopRewriteUtils.isNonZeroIndicator(w, x) //W-X constraint
				&& HopRewriteUtils.isTransposeOfItself(lhs, u) ) //t(U)-U constraint
			{
				v = HopRewriteUtils.isTransposeOperation(v) ?
					v.getInput(0) : HopRewriteUtils.createTranspose(v);
				//X itself (not the indicator) becomes the first operator input
				Hop wdivmm = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, x, u, v, new LiteralOp(-1), 1, true, true);
				wdivmm.setBlocksize(w.getBlocksize());
				wdivmm.refreshSizeInformation();
				//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(wdivmm);
				LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
		if( hnew == null
			&& HopRewriteUtils.isBinary(lhs, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
			&& HopRewriteUtils.isBinary(lhs.getInput(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(lhs.getInput(1).getInput(0))
			&& lhs.getInput(1).getInput(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isSingleBlock(lhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = lhs.getInput(0);
			Hop minus = lhs.getInput(1);
			Hop uvt = minus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = minus.getInput(1);
			if( HopRewriteUtils.isNonZeroIndicator(w, x) //W-X constraint
				&& HopRewriteUtils.isTransposeOfItself(rhs, v) ) //V-t(V) constraint
			{
				v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput(0) : rhs;
				//X itself (not the indicator) becomes the first operator input
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, x, u, v, new LiteralOp(-1), 2, true, true);
				hnew.setBlocksize(w.getBlocksize());
				hnew.refreshSizeInformation();
				LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
		if( hnew == null
			&& HopRewriteUtils.isBinary(rhs, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
			&& HopRewriteUtils.isBinary(rhs.getInput(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(rhs.getInput(1).getInput(0))
			&& rhs.getInput(1).getInput(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isSingleBlock(rhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = rhs.getInput(0);
			Hop minus = rhs.getInput(1);
			Hop uvt = minus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = minus.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(lhs, u) ) { //t(U)-U constraint
				v = HopRewriteUtils.isTransposeOperation(v) ?
					v.getInput(0) : HopRewriteUtils.createTranspose(v);
				//note: x and w exchanged compared to patterns 1-4, 7
				Hop wdivmm = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, x, 1, true, true);
				wdivmm.setBlocksize(w.getBlocksize());
				wdivmm.refreshSizeInformation();
				//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(wdivmm);
				LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
			}
		}

		//Pattern 6) (W*(U%*%t(V)-X)) %*% V
		if( hnew == null
			&& HopRewriteUtils.isBinary(lhs, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
			&& HopRewriteUtils.isBinary(lhs.getInput(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(lhs.getInput(1).getInput(0))
			&& lhs.getInput(1).getInput(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isSingleBlock(lhs.getInput(1).getInput(0).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
		{
			Hop w = lhs.getInput(0);
			Hop minus = lhs.getInput(1);
			Hop uvt = minus.getInput(0);
			Hop u = uvt.getInput(0);
			Hop v = uvt.getInput(1);
			Hop x = minus.getInput(1);
			if( HopRewriteUtils.isTransposeOfItself(rhs, v) ) { //V-t(V) constraint
				v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput(0) : rhs;
				//note: x and w exchanged compared to patterns 1-4, 7
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WDIVMM, w, u, v, x, 2, true, true);
				hnew.setBlocksize(w.getBlocksize());
				hnew.refreshSizeInformation();
				LOG.debug("Applied simplifyWeightedDivMM6 (line "+hi.getBeginLine()+")");
			}
		}
	}

	//Pattern 7) (W*(U%*%t(V)))
	if( hnew == null
		&& HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
		&& HopRewriteUtils.isEqualSize(hi.getInput(0), hi.getInput(1)) //prevent mv
		&& hi.getDim2() > 1 //not applied for vector-vector mult
		&& hi.getInput(0).getDataType() == DataType.MATRIX
		&& hi.getInput(0).getDim2() > hi.getInput(0).getBlocksize()
		&& HopRewriteUtils.isOuterProductLikeMM(hi.getInput(1))
		&& (((AggBinaryOp) hi.getInput(1)).checkMapMultChain() == ChainType.NONE
			|| hi.getInput(1).getInput(1).getDim2() > 1) //no mmchain
		&& HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0), true) ) //BLOCKSIZE CONSTRAINT
	{
		Hop w = hi.getInput(0);
		Hop uvt = hi.getInput(1);
		Hop u = uvt.getInput(0);
		Hop v = uvt.getInput(1);

		//for this basic pattern, we're more conservative and only apply wdivmm if
		//W is sparse and U/V unknown or dense; or if U/V are dense
		if( (HopRewriteUtils.isSparse(w) && !HopRewriteUtils.isSparse(u) && !HopRewriteUtils.isSparse(v))
			|| (HopRewriteUtils.isDense(u) && HopRewriteUtils.isDense(v)) )
		{
			if( HopRewriteUtils.isTransposeOperation(v) )
				v = v.getInput(0);
			else
				v = HopRewriteUtils.createTranspose(v);
			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
				OpOp4.WDIVMM, w, u, v, new LiteralOp(-1), 0, true, false);
			hnew.setBlocksize(w.getBlocksize());
			hnew.refreshSizeInformation();
			LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")");
		}
	}

	//nothing matched: hand back the untouched input hop
	if( hnew == null )
		return hi;

	//relink new hop into original position
	HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
	return hnew;
}