in src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java [278:553]
/**
 * Recursively constructs the CNode plan for {@code hop} (row template) and
 * registers the result in {@code tmp} under the hop's ID.
 * <p>
 * Inputs that the best ROW/CELL memo-table entry marks as plan references are
 * compiled recursively; every other input becomes a {@code CNodeData} leaf and
 * is added to {@code inHops}. The hop itself is then translated by operator
 * type: aggregates, matrix multiplications, seq, transpose, unary/binary/ternary
 * ops, selected DNN ops, n-ary ops, parameterized builtins and right indexing.
 *
 * @param hop             the hop to translate
 * @param memo            memo table; its best ROW or CELL entry for {@code hop}
 *                        determines which inputs are compiled recursively
 *                        (if there is no entry, all inputs become data leaves)
 * @param tmp             map from hop ID to constructed CNode; read for inputs and
 *                        written for {@code hop} and its non-fused inputs; an
 *                        existing entry for {@code hop} short-circuits the call
 * @param inHops          hops that become data inputs of the generated operator;
 *                        added to in place, and removed from for transpose inputs
 *                        of matrix multiplications and for constant cbind operands
 * @param inHops2         named special inputs, modified in place: "X" is set to the
 *                        main (left/first) input in several branches, "B1" to the
 *                        right-hand matrix of vector-matrix and outer-product
 *                        multiplications
 * @param compileLiterals forwarded to TemplateUtils.createCNodeData and
 *                        TemplateUtils.skipTranspose (presumably controls whether
 *                        literal inputs are inlined as constants — TODO confirm)
 * @throws RuntimeException if a matrix unary/binary op is not in the supported
 *                          vector op set
 * @throws HopsException    if no CNode could be constructed for {@code hop}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals)
{
//memoization for common subexpression elimination and to avoid redundant work:
//a hop that already has a CNode in tmp is not translated again
if( tmp.containsKey(hop.getHopID()) )
return;
//recursively process required children; inputs that are not plan references
//of the best ROW/CELL entry become data nodes (leaf inputs of the operator)
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
for( int i=0; i<hop.getInput().size(); i++ ) {
Hop c = hop.getInput().get(i);
if( me!=null && me.isPlanRef(i) )
rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
else {
//NOTE(review): this unconditionally overwrites any CNode already stored for c
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
//construct cnode for current hop
CNode out = null;
if(hop instanceof AggUnaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
//row aggregates (e.g., rowSums) with a supported aggregation op
if( ((AggUnaryOp)hop).getDirection().isRow() && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) {
//single-column input: no aggregate is emitted, the row value is used directly
//(assumes every op in SUPPORTED_ROW_AGG is the identity on one value — verify)
if(hop.getInput().get(0).getDim2()==1)
out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R);
else {
//e.g., SUM -> ROW_SUMS, MIN -> ROW_MINS
String opcode = "ROW_"+((AggUnaryOp)hop).getOp().name().toUpperCase()+"S";
out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
inHops2.put("X", hop.getInput().get(0));
}
}
//column SUM/MEAN: the per-row result vector itself is emitted
//NOTE(review): MEAN is handled exactly like SUM here; presumably the division
//happens elsewhere — confirm
else if ( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MEAN)
&& ((AggUnaryOp)hop).getDirection().isCol() ) { //closes row template
//vector add without temporary copy: replace a vector-scalar primitive by
//its vector-add counterpart over the same two inputs
if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
((CNodeBinary)cdata1).getType().getVectorAddPrimitive());
else
out = cdata1;
}
//full aggregate sum: reduce each row with ROW_SUMS (scalars pass through)
else if( ((AggUnaryOp)hop).getDirection() == Direction.RowCol && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
out = (cdata1.getDataType().isMatrix()) ?
new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
}
//other aggregates leave out == null and fail the check at the end
}
else if(hop instanceof AggBinaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
//t(X) %*% Y: compile against X directly instead of its transpose
if( HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) )
{
//correct input under transpose
cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
//the transpose hop is no longer an input; its input is, if used as data
inHops.remove(hop.getInput().get(0));
if( cdata1 instanceof CNodeData )
inHops.add(hop.getInput().get(0).getInput().get(0));
//note: vectorMultAdd applicable to vector-scalar, and vector-vector
if( hop.getInput().get(1).getDim2() == 1 )
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
else {
//multi-column right-hand side: outer-product multiply-add with B1 as rhs
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
if( !inHops2.containsKey("B1") ) { //incl modification of X for consistency
if( cdata1 instanceof CNodeData )
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
if( !inHops2.containsKey("X") )
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
}
else
{
//both inputs single-column: scalar multiply of the looked-up values
if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1)
out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0),
(cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
//single-column right-hand side: row-wise dot product
else if( hop.getInput().get(1).getDim2()==1 ) {
out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
inHops2.put("X", hop.getInput().get(0));
}
//general case: vector-matrix multiplication with B1 as rhs matrix
else {
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
inHops2.put("X", hop.getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
}
else if( HopRewriteUtils.isDataGenOp(hop, OpOpDG.SEQ) ) {
//seq(from, to, incr) with literal parameters
CNodeData from = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_FROM).getHopID()));
CNodeData to = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_TO).getHopID()));
CNodeData incr = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_INCR).getHopID()));
//descending sequence with positive increment: negate the increment
if( Double.parseDouble(from.getVarname()) > Double.parseDouble(to.getVarname())
&& Double.parseDouble(incr.getVarname()) > 0 ) {
incr = TemplateUtils.createCNodeData(new LiteralOp("-"+incr.getVarname()), true);
}
//'to' is only used for the direction check above
out = new CNodeBinary(from, incr, BinType.SEQ_RIX);
}
else if( HopRewriteUtils.isTransposeOperation(hop) ) {
//NOTE(review): hop is not in tmp yet (see the check on entry), so the first
//argument is null here; presumably skipTranspose resolves the transpose input
//itself for transpose hops — confirm in TemplateUtils
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()),
hop, tmp, compileLiterals);
if( out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)) )
inHops.add(hop.getInput().get(0));
}
else if(hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
//(multi-column input, or unknown dims with a matrix-typed cnode)
if(hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1
|| (!hop.dimsKnown() && cdata1.getDataType()==DataType.MATRIX ) )
{
if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) {
//e.g., EXP -> VECT_EXP
String opname = "VECT_"+((UnaryOp)hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
inHops2.put("X", hop.getInput().get(0));
}
else
throw new RuntimeException("Unsupported unary matrix "
+ "operation: " + ((UnaryOp)hop).getOp().name());
}
else //general scalar case
{
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
String primitiveOpName = ((UnaryOp)hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
}
else if(HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
//special case for cbind with zeros
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = null;
//constant right-hand side (datagen with constant value): inline it as a
//literal and drop the constant matrix from the operator inputs
if( HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)) ) {
cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils
.getDataGenOpConstantValue(hop.getInput().get(1)), true);
inHops.remove(hop.getInput().get(1)); //rm 0-matrix
}
else {
cdata2 = tmp.get(hop.getInput().get(1).getHopID());
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1), true);
}
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
inHops2.put("X", hop.getInput().get(0));
}
else if(hop instanceof BinaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
//(either input multi-column, or some dims unknown, output not a known
//single-column result, and at least one matrix-typed cnode)
if( (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1)
|| (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1)
|| (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown())
&& (hop.getDim2() != 1) //not a known vector output
&& (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
{
if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) {
//column-vector operands are read per row via LOOKUP_R
if( TemplateUtils.isColVector(cdata1) )
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if( TemplateUtils.isColVector(cdata2) )
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = getVectorBinary(cdata1, cdata2, ((BinaryOp)hop).getOp().name());
//only a non-scalar data node (not wrapped above) becomes X
if( cdata1 instanceof CNodeData && !inHops2.containsKey("X")
&& !(cdata1.getDataType()==DataType.SCALAR) ) {
inHops2.put("X", hop.getInput().get(0));
}
}
else
throw new RuntimeException("Unsupported binary matrix "
+ "operation: " + ((BinaryOp)hop).getOp().name());
}
else //one input is a vector/scalar other is a scalar
{
String primitiveOpName = ((BinaryOp)hop).getOp().name();
if( TemplateUtils.isColVector(cdata1) )
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if( TemplateUtils.isColVector(cdata2) //vector or vector can be inferred from lhs
|| (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData
&& hop.getInput().get(1).getDataType().isMatrix()))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
}
}
else if(hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
if( hop.getDim2() >= 2 ) { //matrices
//a +* (b,c) -> a + b*c, otherwise a - b*c
//NOTE(review): any op other than PLUS_MULT maps to VECT_MINUS here; presumably
//only PLUS_MULT/MINUS_MULT reach this branch — verify against plan selection
out = new CNodeBinary(cdata1, new CNodeBinary(cdata2, cdata3, BinType.VECT_MULT_SCALAR),
top.getOp()==OpOp3.PLUS_MULT? BinType.VECT_PLUS : BinType.VECT_MINUS);
}
else { //column vectors
//add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
//construct scalar ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3,
TernaryType.valueOf(top.getOp().name()));
}
}
else if( HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) ) {
//bias_add/bias_mult -> VECT_BIASADD/VECT_BIASMULT
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
out = new CNodeBinary(cdata1, cdata2,
BinType.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
}
else if( HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) ) {
//pooling over all inputs -> VECT_MAX_POOL/VECT_AVG_POOL
CNode[] in = hop.getInput().stream().map(h ->
tmp.get(h.getHopID())).toArray(CNode[]::new);
out = new CNodeNary(in, CNodeNary.NaryType
.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
}
else if( HopRewriteUtils.isDnn(hop, OpOpDnn.CONV2D) ) {
//conv2d as im2col (all inputs except input 1) followed by a matrix multiply
//(all inputs, with input 0 replaced by the im2col node)
CNode[] in1 = hop.getInput().stream().filter(h -> h!=hop.getInput().get(1))
.map(h ->tmp.get(h.getHopID())).toArray(CNode[]::new);
CNode im2col = new CNodeNary(in1, CNodeNary.NaryType.VECT_IM2COL);
CNode[] in2 = hop.getInput().stream().map(h -> (h==hop.getInput().get(0)) ?
im2col : tmp.get(h.getHopID())).toArray(CNode[]::new);
out = new CNodeNary(in2, CNodeNary.NaryType.VECT_CONV2DMM);
}
else if( hop instanceof NaryOp ) {
//collect input cnodes, wrapping vector inputs with lookups where necessary;
//the first input becomes X if it is a data node
CNode[] inputs = new CNode[hop.getInput().size()];
for( int i=0; i<hop.getInput().size(); i++ ) {
Hop c = hop.getInput().get(i);
CNode cdata = tmp.get(c.getHopID());
if( TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata) )
cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
inputs[i] = cdata;
if( i==0 && cdata instanceof CNodeData && !inHops2.containsKey("X") )
inHops2.put("X", c);
}
if( HopRewriteUtils.isNary(hop, OpOpN.CBIND) ) {
out = new CNodeNary(inputs, NaryType.VECT_CBIND);
}
//n-ary min/max/plus: left-to-right chain of binary ops
else if( HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ) {
out = getVectorOrScalarBinary(inputs[0], inputs[1], ((NaryOp)hop).getOp().name());
for( int i=2; i<hop.getInput().size(); i++ )
out = getVectorOrScalarBinary(out, inputs[i], ((NaryOp)hop).getOp().name());
}
//other n-ary ops leave out == null and fail the check at the end
}
else if( hop instanceof ParameterizedBuiltinOp ) {
//replace(target, pattern, replacement); a literal NaN pattern uses REPLACE_NAN
//NOTE(review): assumes only replace reaches this branch (reads the "pattern"
//and "replacement" parameters) — verify against plan selection
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
}
else if( hop instanceof IndexingOp ) {
//right indexing: arguments are the input, its number of columns (literal) and
//input 4 of the indexing hop (presumably the column lower bound — confirm);
//multi-column result -> LOOKUP_RVECT1, single column -> LOOKUP_RC1
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
out = new CNodeTernary(cdata1,
TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
TemplateUtils.createCNodeData(hop.getInput().get(4), true),
(hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
}
//no branch produced a cnode: unsupported hop for the row template
if( out == null ) {
throw new HopsException(hop.getHopID()+" "+hop.getOpString());
}
//propagate output dimensions of matrix-typed cnodes from the hop
if( out.getDataType().isMatrix() ) {
out.setNumRows(hop.getDim1());
out.setNumCols(hop.getDim2());
}
//memoize for CSE and reuse by parent hops
tmp.put(hop.getHopID(), out);
}