private void rConstructCplan()

in src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java [278:553]


	/**
	 * Recursively constructs the row-template CNode plan (cplan) for the given hop and
	 * stores the resulting CNode in {@code tmp} under the hop ID.
	 * <p>
	 * Inputs that are plan references of the best ROW/CELL memo table entry of {@code hop}
	 * are constructed recursively. All other inputs become {@link CNodeData} leaves and
	 * are added to {@code inHops}. Hops whose ID is already present in {@code tmp} are
	 * skipped, which provides common subexpression elimination.
	 * 
	 * @param hop current hop to compile into a CNode
	 * @param memo memo table of fusion plans; its best ROW/CELL entry for {@code hop}
	 *   decides which inputs are fused (plan references) and which are read as data
	 * @param tmp map from hop ID to already constructed CNodes (memoization)
	 * @param inHops hops read as data inputs of the fused operator; may be adjusted here,
	 *   e.g., a transpose input is replaced by its own input
	 * @param inHops2 named special inputs collected under the keys "X" and "B1";
	 *   presumably the main row-wise input and the matrix side input of vector-matrix
	 *   multiplies (verify against the caller)
	 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
	 *   {@code TemplateUtils.skipTranspose} when creating data nodes
	 * @throws HopsException if no CNode could be constructed for the hop
	 * @throws RuntimeException for unary/binary operations on matrices that are not
	 *   supported as vector primitives
	 */
	private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) 
	{
		//memoization for common subexpression elimination and to avoid redundant work 
		if( tmp.containsKey(hop.getHopID()) )
			return;
		
		//recursively process required childs; inputs that are not plan references
		//of the best row/cell memo entry become data inputs of the fused operator
		MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
		for( int i=0; i<hop.getInput().size(); i++ ) {
			Hop c = hop.getInput().get(i);
			if( me!=null && me.isPlanRef(i) )
				rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
			else {
				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
				tmp.put(c.getHopID(), cdata);
				inHops.add(c);
			}
		}
		
		//construct cnode for current hop
		CNode out = null;
		if(hop instanceof AggUnaryOp)
		{
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			if( ((AggUnaryOp)hop).getDirection().isRow() && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) {
				//row aggregate of a single-column input: per-row lookup of that column
				if(hop.getInput().get(0).getDim2()==1)
					out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R);
				else {
					//map the aggregate op to its row-wise unary type, e.g., SUM -> ROW_SUMS
					String opcode = "ROW_"+((AggUnaryOp)hop).getOp().name().toUpperCase()+"S";
					out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
					if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
						inHops2.put("X", hop.getInput().get(0));
				}
			}
			else if ( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MEAN) 
				&& ((AggUnaryOp)hop).getDirection().isCol() ) { //closes row template
				//vector add without temporary copy
				if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
					out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
							((CNodeBinary)cdata1).getType().getVectorAddPrimitive());
				else
					out = cdata1;
			}
			else if( ((AggUnaryOp)hop).getDirection() == Direction.RowCol && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
				//full sum: row sums of matrix inputs, scalars pass through
				out = (cdata1.getDataType().isMatrix()) ?
					new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
			}
		}
		else if(hop instanceof AggBinaryOp)
		{
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
			
			if( HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)) )
			{
				//correct input under transpose
				cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
				inHops.remove(hop.getInput().get(0));
				if( cdata1 instanceof CNodeData )
					inHops.add(hop.getInput().get(0).getInput().get(0));
				
				//note: vectorMultAdd applicable to vector-scalar, and vector-vector
				if( hop.getInput().get(1).getDim2() == 1 )
					out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
				else {
					out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
					if( !inHops2.containsKey("B1") ) { //incl modification of X for consistency
						if( cdata1 instanceof CNodeData )
							inHops2.put("X", hop.getInput().get(0).getInput().get(0));
						inHops2.put("B1", hop.getInput().get(1));
					}
				}
				if( !inHops2.containsKey("X") )
					inHops2.put("X", hop.getInput().get(0).getInput().get(0));
			}
			else
			{
				//both inputs single-column: scalar multiply of per-row values
				if(hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1)
					out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0),
						(cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
				else if( hop.getInput().get(1).getDim2()==1 ) {
					out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
					inHops2.put("X", hop.getInput().get(0));
				}
				else {
					out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
					inHops2.put("X", hop.getInput().get(0));
					inHops2.put("B1", hop.getInput().get(1));
				}
			}
		}
		else if( HopRewriteUtils.isDataGenOp(hop, OpOpDG.SEQ) ) {
			CNodeData from = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_FROM).getHopID()));
			CNodeData to = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_TO).getHopID()));
			CNodeData incr = TemplateUtils.getLiteral(tmp.get(((DataGenOp)hop).getParam(Statement.SEQ_INCR).getHopID()));
			//decreasing sequence (from > to) with positive increment: negate the increment
			if( Double.parseDouble(from.getVarname()) > Double.parseDouble(to.getVarname())
				&& Double.parseDouble(incr.getVarname()) > 0 ) {
				incr = TemplateUtils.createCNodeData(new LiteralOp("-"+incr.getVarname()), true);
			}
			out = new CNodeBinary(from, incr, BinType.SEQ_RIX);
		}
		else if( HopRewriteUtils.isTransposeOperation(hop) ) {
			//note: tmp holds no entry for this hop yet (see memoization check above)
			out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()),
				hop, tmp, compileLiterals);
			if( out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)) )
				inHops.add(hop.getInput().get(0));
		}
		else if(hop instanceof UnaryOp) {
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			
			// if one input is a matrix then we need to do vector by scalar operations
			if(hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 
				|| (!hop.dimsKnown() && cdata1.getDataType()==DataType.MATRIX ) ) 
			{
				if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) {
					String opname = "VECT_"+((UnaryOp)hop).getOp().name();
					out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
					if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
						inHops2.put("X", hop.getInput().get(0));
				}
				else 
					throw new RuntimeException("Unsupported unary matrix "
						+ "operation: " + ((UnaryOp)hop).getOp().name());
			}
			else //general scalar case
			{
				cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
				String primitiveOpName = ((UnaryOp)hop).getOp().name();
				out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
			}
		}
		else if(HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
			//special case for cbind with zeros
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			CNode cdata2 = null;
			if( HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)) ) {
				//constant right-hand side is compiled as a literal instead of a data input
				cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils
					.getDataGenOpConstantValue(hop.getInput().get(1)), true);
				inHops.remove(hop.getInput().get(1)); //rm 0-matrix
			}
			else {
				cdata2 = tmp.get(hop.getInput().get(1).getHopID());
				cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1), true);
			}
			out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
			if( cdata1 instanceof CNodeData && !inHops2.containsKey("X") )
				inHops2.put("X", hop.getInput().get(0));
		}
		else if(hop instanceof BinaryOp)
		{
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
			
			// if one input is a matrix then we need to do vector by scalar operations
			if( (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1)
				|| (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1)
				|| (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown())
					&& (hop.getDim2() != 1) //not a known vector output
					&& (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
			{
				if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) {
					if( TemplateUtils.isColVector(cdata1) )
						cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
					if( TemplateUtils.isColVector(cdata2) )
						cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
					out = getVectorBinary(cdata1, cdata2, ((BinaryOp)hop).getOp().name());
					if( cdata1 instanceof CNodeData && !inHops2.containsKey("X")
						&& !(cdata1.getDataType()==DataType.SCALAR) ) {
						inHops2.put("X", hop.getInput().get(0));
					}
				}
				else 
					throw new RuntimeException("Unsupported binary matrix "
						+ "operation: " + ((BinaryOp)hop).getOp().name());
			}
			else //one input is a vector/scalar other is a scalar
			{
				String primitiveOpName = ((BinaryOp)hop).getOp().name();
				if( TemplateUtils.isColVector(cdata1) )
					cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
				if( TemplateUtils.isColVector(cdata2) //vector or vector can be inferred from lhs
					|| (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData
						&& hop.getInput().get(1).getDataType().isMatrix()))
					cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
				out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
			}
		}
		else if(hop instanceof TernaryOp) {
			TernaryOp top = (TernaryOp) hop;
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
			CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
			
			if( hop.getDim2() >= 2 ) { //matrices
				//in1 +/- (in2 * in3) as vector primitives, depending on PLUS_MULT vs. otherwise
				out = new CNodeBinary(cdata1, new CNodeBinary(cdata2, cdata3, BinType.VECT_MULT_SCALAR),
					top.getOp()==OpOp3.PLUS_MULT? BinType.VECT_PLUS : BinType.VECT_MINUS);
			}
			else { //column vectors
				//add lookups if required
				cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
				cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
				cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
				
				//construct scalar ternary cnode, primitive operation derived from OpOp3 
				out = new CNodeTernary(cdata1, cdata2, cdata3, 
					TernaryType.valueOf(top.getOp().name()));
			}
		}
		else if( HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) ) {
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
			out = new CNodeBinary(cdata1, cdata2,
				BinType.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
		}
		else if( HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) ) {
			CNode[] in = hop.getInput().stream().map(h ->
				tmp.get(h.getHopID())).toArray(CNode[]::new);
			out = new CNodeNary(in, CNodeNary.NaryType
				.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
		}
		else if( HopRewriteUtils.isDnn(hop, OpOpDnn.CONV2D) ) {
			//im2col over all inputs except input 1 (presumably the filter), followed by
			//a matrix multiply that consumes im2col in place of input 0; inputs are
			//selected by position so that a hop shared between input 0 and input 1
			//is neither dropped from im2col nor replaced twice
			int numIn = hop.getInput().size();
			CNode[] in1 = new CNode[numIn-1];
			CNode[] in2 = new CNode[numIn];
			for( int i=0, j=0; i<numIn; i++ ) {
				CNode cin = tmp.get(hop.getInput().get(i).getHopID());
				if( i != 1 )
					in1[j++] = cin;
				in2[i] = cin;
			}
			CNode im2col = new CNodeNary(in1, CNodeNary.NaryType.VECT_IM2COL);
			in2[0] = im2col;
			out = new CNodeNary(in2, CNodeNary.NaryType.VECT_CONV2DMM);
		}
		else if( hop instanceof NaryOp ) {
			CNode[] inputs = new CNode[hop.getInput().size()];
			for( int i=0; i<hop.getInput().size(); i++ ) {
				Hop c = hop.getInput().get(i);
				CNode cdata = tmp.get(c.getHopID());
				if( TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata) )
					cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
				inputs[i] = cdata;
				if( i==0 && cdata instanceof CNodeData && !inHops2.containsKey("X") )
					inHops2.put("X", c);
			}
			if( HopRewriteUtils.isNary(hop, OpOpN.CBIND) ) {
				out = new CNodeNary(inputs, NaryType.VECT_CBIND);
			}
			else if( HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ) {
				//left-deep chain of binary operations over all inputs
				out = getVectorOrScalarBinary(inputs[0], inputs[1], ((NaryOp)hop).getOp().name());
				for( int i=2; i<hop.getInput().size(); i++ )
					out = getVectorOrScalarBinary(out, inputs[i], ((NaryOp)hop).getOp().name());
			}
		}
		else if( hop instanceof ParameterizedBuiltinOp ) {
			ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
			//decide the lookup wrapping on the same hop the data node is taken from,
			//rather than assuming the target is positional input 0
			Hop target = pbop.getTargetHop();
			CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(target.getHopID()), target);
			CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
			CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
			//a literal NaN pattern selects the dedicated NaN replacement primitive
			TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
				TernaryType.REPLACE_NAN : TernaryType.REPLACE;
			out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
		}
		else if( hop instanceof IndexingOp ) {
			//lookup on input 0, parameterized by its number of columns and indexing input 4
			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
			out = new CNodeTernary(cdata1, 
				TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
				TemplateUtils.createCNodeData(hop.getInput().get(4), true),
				(hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
		}
		
		if( out == null ) {
			throw new HopsException("Failed to construct row template cplan for hop "
				+ hop.getHopID()+" "+hop.getOpString());
		}
		
		//propagate output dimensions for matrix-valued cnodes
		if( out.getDataType().isMatrix() ) {
			out.setNumRows(hop.getDim1());
			out.setNumCols(hop.getDim2());
		}
		
		tmp.put(hop.getHopID(), out);
	}