in src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java [2175:2313]
/**
 * Fuses a weighted unary matrix multiplication into a single {@code OpOp4.WUMM}
 * {@code QuaternaryOp}. If a pattern matches, {@code hi} is replaced by the new hop
 * in the inputs of {@code parent} at position {@code pos}.
 * <p>
 * Patterns are tried in this order and the first match wins:
 * <ul>
 *   <li>1) {@code W op uop(U %*% t(V))}, outer op in {@code LOOKUP_VALID_WDIVMM_BINARY},
 *       uop in {@code LOOKUP_VALID_WUMM_UNARY}</li>
 *   <li>2.7) {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}</li>
 *   <li>2) {@code W op sop(U %*% t(V), 2)} (sop in {@code LOOKUP_VALID_WUMM_BINARY}) or
 *       {@code W op (2 * (U %*% t(V)))}</li>
 * </ul>
 *
 * @param parent parent of {@code hi}; its input at {@code pos} is relinked on success
 * @param hi     candidate hop to rewrite
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new WUMM hop if a pattern was applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
	Hop hnew = matchWeightedUnaryMMPattern1(hi);
	if( hnew == null )
		hnew = matchWeightedUnaryMMPattern27(hi);
	if( hnew == null )
		hnew = matchWeightedUnaryMMPattern2(hi);
	
	//relink new hop into original position
	if( hnew != null ) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}
	return hi;
}

/** Pattern 1) (W*uop(U%*%t(V))); returns the fused WUMM hop, or null if hi does not match. */
private static Hop matchWeightedUnaryMMPattern1(Hop hi) {
	if( !(hi instanceof BinaryOp
		&& HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
		&& isWeightedUnaryMMCandidate(hi)
		&& hi.getInput(1) instanceof UnaryOp
		&& HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WUMM_UNARY)
		&& hi.getInput(1).getInput(0) instanceof AggBinaryOp
		&& HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0).getInput(0), true)) ) //BLOCKSIZE CONSTRAINT
		return null;
	
	Hop mm = hi.getInput(1).getInput(0);
	boolean mult = ((BinaryOp)hi).getOp() == OpOp2.MULT;
	OpOp1 uop = ((UnaryOp)hi.getInput(1)).getOp();
	Hop hnew = buildWeightedUnaryMM(hi, hi.getInput(0), mm.getInput(0), mm.getInput(1), mult, uop, null);
	LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
	return hnew;
}

/**
 * Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V))); returns the fused WUMM hop,
 * or null if hi does not match. The inner product must have no other consumers.
 */
private static Hop matchWeightedUnaryMMPattern27(Hop hi) {
	if( !(HopRewriteUtils.isBinary(hi, OpOp2.MULT)
		&& (HopRewriteUtils.isLiteralOfValue(hi.getInput(0), 2)
			|| HopRewriteUtils.isLiteralOfValue(hi.getInput(1), 2))) )
		return null;
	
	//the non-literal side must be the weighted product W*(U%*%t(V))
	Hop nl = (hi.getInput(0) instanceof LiteralOp) ? hi.getInput(1) : hi.getInput(0);
	if( !(HopRewriteUtils.isBinary(nl, OpOp2.MULT)
		&& nl.getParent().size()==1 // ensure no foreign parents
		&& isWeightedUnaryMMCandidate(nl)
		&& HopRewriteUtils.isOuterProductLikeMM(nl.getInput(1))
		&& (((AggBinaryOp) nl.getInput(1)).checkMapMultChain() == ChainType.NONE
			|| nl.getInput(1).getInput(1).getDim2() > 1) //no mmchain
		&& HopRewriteUtils.isSingleBlock(nl.getInput(1).getInput(0), true)) )
		return null;
	
	Hop mm = nl.getInput(1);
	Hop hnew = buildWeightedUnaryMM(hi, nl.getInput(0), mm.getInput(0), mm.getInput(1), true, null, OpOp2.MULT);
	LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
	return hnew;
}

/**
 * Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops, with c == 2;
 * the scalar-left form sop(2, U%*%t(V)) is only accepted for MULT.
 * Returns the fused WUMM hop, or null if hi does not match.
 */
private static Hop matchWeightedUnaryMMPattern2(Hop hi) {
	if( !(hi instanceof BinaryOp
		&& HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
		&& isWeightedUnaryMMCandidate(hi)
		&& hi.getInput(1) instanceof BinaryOp
		&& HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) )
		return null;
	
	BinaryOp inner = (BinaryOp) hi.getInput(1);
	Hop left = inner.getInput(0);
	Hop right = inner.getInput(1);
	Hop abop = null;
	//pattern 2a) matrix-scalar operations
	if( right.getDataType()==DataType.SCALAR && right instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2
		&& left instanceof AggBinaryOp
		&& HopRewriteUtils.isSingleBlock(left.getInput(0), true) ) //BLOCKSIZE CONSTRAINT
	{
		abop = left;
	}
	//pattern 2b) scalar-matrix operations (only commutative mult2)
	else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2
		&& inner.getOp() == OpOp2.MULT
		&& right instanceof AggBinaryOp
		&& HopRewriteUtils.isSingleBlock(right.getInput(0), true) ) //BLOCKSIZE CONSTRAINT
	{
		abop = right;
	}
	if( abop == null )
		return null;
	
	boolean mult = ((BinaryOp)hi).getOp() == OpOp2.MULT;
	Hop hnew = buildWeightedUnaryMM(hi, hi.getInput(0), abop.getInput(0), abop.getInput(1), mult, null, inner.getOp());
	LOG.debug("Applied simplifyWeightedUnaryMM2 (line "+hi.getBeginLine()+")");
	return hnew;
}

/**
 * Shared shape guard on the outer weighting op {@code b} (inputs W and the product side).
 * The inputs must have equal size, which prevents matrix-vector forms.
 * The output must have more than one column, which excludes vector-vector products.
 * W must be a matrix with more columns than its block size.
 */
private static boolean isWeightedUnaryMMCandidate(Hop b) {
	Hop W = b.getInput(0);
	return HopRewriteUtils.isEqualSize(W, b.getInput(1)) //prevent mv
		&& b.getDim2() > 1 //not applied for vector-vector mult
		&& W.getDataType() == DataType.MATRIX
		&& W.getDim2() > W.getBlocksize();
}

/**
 * Creates the fused WUMM hop for weights {@code W} and the matrix product {@code U %*% R}.
 * WUMM receives {@code t(R)}. If {@code R} is already a transpose {@code t(X)}, {@code X} is
 * passed directly instead of building {@code t(t(X))}.
 * The new hop takes the name of {@code hi} and the block size of {@code W}.
 */
private static Hop buildWeightedUnaryMM(Hop hi, Hop W, Hop U, Hop R, boolean mult, OpOp1 uop, OpOp2 sop) {
	Hop V = HopRewriteUtils.isTransposeOperation(R) ?
		R.getInput(0) : HopRewriteUtils.createTranspose(R);
	Hop hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
		OpOp4.WUMM, W, U, V, mult, uop, sop);
	hnew.setBlocksize(W.getBlocksize());
	hnew.refreshSizeInformation();
	return hnew;
}