private Hop processBuiltinFunctionExpression()

in src/main/java/org/apache/sysds/parser/DMLTranslator.java [2298:2831]


	/**
	 * Constructs the hop for a builtin function expression.
	 * <p>
	 * The first, second, and third argument expressions (if present) are translated
	 * into hops up front; cases that need further arguments translate them on demand
	 * via {@code processExpression} or {@code processAllExpressions}. The opcode of
	 * {@code source} then selects which hop is constructed. For all opcodes except the
	 * conv2d and pooling builtins, the identifier parameters of {@code source.getOutput()}
	 * are applied to the resulting hop. Finally, the parse info of {@code source}
	 * is attached to the hop.
	 * <p>
	 * Side effect: the {@code LINEAGE} builtin sets the global flag {@code DMLScript.LINEAGE}.
	 *
	 * @param source builtin function expression to translate
	 * @param target output identifier; if {@code null}, one is created via {@code createTarget(source)}
	 * @param hops   map of variable names to hops, passed through to expression processing
	 * @return the constructed hop
	 * @throws ParseException if the ppred operator string is unknown, table() is called with
	 *         an invalid number of arguments, or the builtin function type is unsupported
	 * @throws HopsException if the third argument of sample() has an invalid value type, or the
	 *         operator of outer() is not a literal or not a supported binary operation
	 */
	private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target,
			HashMap<String, Hop> hops) {
		// translate up to three leading arguments; absent arguments stay null
		Hop expr = null;
		if(source.getFirstExpr() != null){
			expr = processExpression(source.getFirstExpr(), null, hops);
		}
		Hop expr2 = null;
		if (source.getSecondExpr() != null) {
			expr2 = processExpression(source.getSecondExpr(), null, hops);
		}
		Hop expr3 = null;
		if (source.getThirdExpr() != null) {
			expr3 = processExpression(source.getThirdExpr(), null, hops);
		}

		Hop currBuiltinOp = null;
		target = (target == null) ? createTarget(source) : target;

		// Construct the hop based on the type of Builtin function
		switch (source.getOpCode()) {

		case EVAL:
		case EVALLIST:
			currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
			break;

		case COLSUM:
		case COLMAX:
		case COLMIN:
		case COLMEAN:
		case COLPROD:
		case COLVAR:
			// the aggregation op is derived from the opcode name without its "COL" prefix
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
				AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Col, expr);
			break;

		case COLSD:
			// colStdDevs = sqrt(colVariances)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
					target.getValueType(), AggOp.VAR, Direction.Col, expr);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
					target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;

		case ROWSUM:
		case ROWMIN:
		case ROWMAX:
		case ROWMEAN:
		case ROWPROD:
		case ROWVAR:
			// the aggregation op is derived from the opcode name without its "ROW" prefix
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
				AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Row, expr);
			break;

		case ROWINDEXMAX:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MAXINDEX,
					Direction.Row, expr);
			break;

		case ROWINDEXMIN:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MININDEX,
					Direction.Row, expr);
			break;

		case ROWSD:
			// rowStdDevs = sqrt(rowVariances)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
					target.getValueType(), AggOp.VAR, Direction.Row, expr);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
					target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;

		case NROW:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim1()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.NROW, expr) : new LiteralOp(expr.getDim1());
			break;
		case NCOL:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.NCOL, expr) : new LiteralOp(expr.getDim2());
			break;
		case LENGTH:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim1()==-1 || expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.LENGTH, expr) : new LiteralOp(expr.getDim1()*expr.getDim2());
			break;

		case LINEAGE:
			//construct hop and enable lineage tracing if necessary
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.LINEAGE, expr);
			DMLScript.LINEAGE = true;
			break;

		case LIST:
			currBuiltinOp = new NaryOp(target.getName(), DataType.LIST, ValueType.UNKNOWN,
				OpOpN.LIST, processAllExpressions(source.getAllExpr(), hops));
			break;

		case EXISTS:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), OpOp1.EXISTS, expr);
			break;

		case SUM:
		case PROD:
		case VAR:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
					AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
			break;

		case MEAN:
			if ( expr2 == null ) {
				// example: x = mean(Y);
				currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.MEAN,
					Direction.RowCol, expr);
			}
			else {
				// example: x = mean(Y,W);
				// stable weighted mean is implemented by using centralMoment with order = 0
				Hop orderHop = new LiteralOp(0);
				currBuiltinOp=new TernaryOp(target.getName(), DataType.SCALAR,
					target.getValueType(), OpOp3.MOMENT, expr, expr2, orderHop);
			}
			break;

		case SD:
			// stdDev = sqrt(variance)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
			HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;

		case MIN:
		case MAX:
			//construct AggUnary for min(X) but BinaryOp for min(X,Y) and NaryOp for min(X,Y,Z)
			currBuiltinOp = (expr2 == null) ? 
				new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
					AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr) : 
				(source.getAllExpr().length == 2) ?
				new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
					OpOp2.valueOf(source.getOpCode().name()), expr, expr2) : 
				new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
					OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
			break;

		case PPRED:
			// the third argument carries the comparison operator as a string
			// NOTE(review): the cast assumes the third argument is a StringIdentifier - presumably
			// guaranteed by earlier validation; confirm against the caller
			String sop = ((StringIdentifier)source.getThirdExpr()).getValue();
			sop = sop.replace("\"", "");
			OpOp2 operation;
			if ( sop.equalsIgnoreCase(Opcodes.GREATEREQUAL.toString()) )
				operation = OpOp2.GREATEREQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.GREATER.toString()) )
				operation = OpOp2.GREATER;
			else if ( sop.equalsIgnoreCase(Opcodes.LESSEQUAL.toString()) )
				operation = OpOp2.LESSEQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.LESS.toString()) )
				operation = OpOp2.LESS;
			else if ( sop.equalsIgnoreCase(Opcodes.EQUAL.toString()) )
				operation = OpOp2.EQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
				operation = OpOp2.NOTEQUAL;
			else {
				throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
			}
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
			break;

		case TRACE:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.TRACE,
					Direction.RowCol, expr);
			break;

		case TRANS:
		case DIAG:
		case REV:
			currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
				target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr);
			break;

		case ROLL:
			ArrayList<Hop> inputs = new ArrayList<>();
			inputs.add(expr);
			inputs.add(expr2);
			currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
					target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs);
			break;

		case CBIND:
		case RBIND:
			// exactly two inputs use the binary append, any other count the n-ary append
			OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
			OpOpN appendOpN = (source.getOpCode()==Builtins.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
			currBuiltinOp = (source.getAllExpr().length == 2) ?
				new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp2, expr, expr2) :
				new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOpN,
					processAllExpressions(source.getAllExpr(), hops));
			break;

		case TABLE:

			// Always a TernaryOp is created for table().
			// - create a hop for weights, if not provided in the function call.
			int numTableArgs = source._args.length;

			switch(numTableArgs) {
			case 2:
			case 4:
				// example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
				// here, weight is interpreted as 1.0
				Hop weightHop = new LiteralOp(1.0);
				// set dimensions
				weightHop.setDim1(0);
				weightHop.setDim2(0);
				weightHop.setNnz(-1);
				weightHop.setBlocksize(0);

				if ( numTableArgs == 2 )
					currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
				else {
					Hop outDim1 = processExpression(source._args[2], null, hops);
					Hop outDim2 = processExpression(source._args[3], null, hops);
					currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
						OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2, new LiteralOp(true));
				}
				break;

			case 3:
			case 5:
			case 6:
				// example DML statement: F = ctable(A,B,W), F = ctable(A,B,W,10,15),
				// or the 6-argument form whose last argument is passed on as outputEmptyBlocks
				if (numTableArgs == 3) 
					currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
				else {
					Hop outDim1 = processExpression(source._args[3], null, hops);
					Hop outDim2 = processExpression(source._args[4], null, hops);
					Hop outputEmptyBlocks = numTableArgs == 6 ?
						processExpression(source._args[5], null, hops) : new LiteralOp(true);
					currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
						OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2, outputEmptyBlocks);
				}
				break;

			default: 
				throw new ParseException("Invalid number of arguments "+ numTableArgs + " to table() function.");
			}
			break;

		//data type casts
		case CAST_AS_SCALAR:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), OpOp1.CAST_AS_SCALAR, expr);
			break;
		case CAST_AS_MATRIX:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp1.CAST_AS_MATRIX, expr);
			break;
		case CAST_AS_FRAME:
			if(expr2 != null)
				currBuiltinOp = new BinaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp2.CAST_AS_FRAME, expr, expr2);
			else
				currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp1.CAST_AS_FRAME, expr);
			break;
		case CAST_AS_LIST:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.LIST, target.getValueType(), OpOp1.CAST_AS_LIST, expr);
			break;

		//value type casts
		case CAST_AS_DOUBLE:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.CAST_AS_DOUBLE, expr);
			break;
		case CAST_AS_INT:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT64, OpOp1.CAST_AS_INT, expr);
			break;
		case CAST_AS_BOOLEAN:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, OpOp1.CAST_AS_BOOLEAN, expr);
			break;
		case LOCAL:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.LOCAL, expr);
			break;
		case COMPRESS:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.COMPRESS, expr);
			break;
		case DECOMPRESS:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
			break;

		// Boolean binary
		case XOR:
		case BITWAND:
		case BITWOR:
		case BITWXOR:
		case BITWSHIFTL:
		case BITWSHIFTR:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;
		// unary builtins whose OpOp1 constant has the same name as the builtin opcode
		case ABS:
		case SIN:
		case COS:
		case TAN:
		case ASIN:
		case ACOS:
		case ATAN:
		case SINH:
		case COSH:
		case TANH:
		case SIGN:
		case SQRT:
		case EXP:
		case ROUND:
		case CEIL:
		case FLOOR:
		case CUMSUM:
		case CUMPROD:
		case CUMSUMPROD:
		case CUMMIN:
		case CUMMAX:
		case ISNA:
		case ISNAN:
		case ISINF:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOp1.valueOf(source.getOpCode().name()), expr);
			break;
		// binary builtins whose OpOp2 constant has the same name as the builtin opcode
		case DROP_INVALID_TYPE:
		case DROP_INVALID_LENGTH:
		case VALUE_SWAP:
		case FRAME_ROW_REPLICATE:
		case APPLY_SCHEMA:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;
		case MAP:
			// a missing third argument defaults to the literal 0
			currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.valueOf(source.getOpCode().name()),
				expr, expr2, (expr3==null) ? new LiteralOp(0L) : expr3);
			break;

		case LOG:
			// one-argument form creates a UnaryOp, two-argument form a BinaryOp;
			// the opcode is known to be LOG here, so the operation type is used directly
			currBuiltinOp = (expr2 == null) ?
				new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.LOG, expr) :
				new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.LOG, expr, expr2);
			break;

		case MOMENT:
		case COV:
		case QUANTILE:
		case INTERQUANTILE:
			// binary hop without a third argument, ternary hop otherwise
			currBuiltinOp = (expr3 == null) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOp2.valueOf(source.getOpCode().name()), expr, expr2) :  new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.valueOf(source.getOpCode().name()), expr, expr2,expr3);
			break;

		case IQM:
		case MEDIAN:
			// unary hop without a second argument, binary hop otherwise
			currBuiltinOp = (expr2 == null) ? new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), 
				OpOp1.valueOf(source.getOpCode().name()), expr) : new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;

		case IFELSE:
			currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.IFELSE, expr, expr2, expr3);
			break;

		case SEQ:
			HashMap<String,Hop> randParams = new HashMap<>();
			randParams.put(Statement.SEQ_FROM, expr);
			randParams.put(Statement.SEQ_TO, expr2);
			randParams.put(Statement.SEQ_INCR, (expr3!=null)?expr3 : new LiteralOp(1)); 
			//note incr: default -1 (for from>to) handled during runtime
			currBuiltinOp = new DataGenOp(OpOpDG.SEQ, target, randParams);
			break;

		case TIME:
			currBuiltinOp = new DataGenOp(OpOpDG.TIME, target);
			break;

		case SAMPLE:
		{
			Expression[] in = source.getAllExpr();

			// arguments: range/size/replace/seed; defaults: replace=FALSE

			HashMap<String,Hop> tmpparams = new HashMap<>();
			tmpparams.put(DataExpression.RAND_MAX, expr); //range
			tmpparams.put(DataExpression.RAND_ROWS, expr2);
			tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));

			if ( in.length == 4 ) 
			{
				tmpparams.put(DataExpression.RAND_PDF, expr3);
				Hop seed = processExpression(in[3], null, hops);
				tmpparams.put(DataExpression.RAND_SEED, seed);
			}
			else if ( in.length == 3 )
			{
				// check if the third argument is "replace" or "seed"
				if ( expr3.getValueType() == ValueType.BOOLEAN ) 
				{
					tmpparams.put(DataExpression.RAND_PDF, expr3);
					tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
				}
				else if ( expr3.getValueType() == ValueType.INT64 ) 
				{
					tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
					tmpparams.put(DataExpression.RAND_SEED, expr3 );
				}
				else 
					throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");

			}
			else if ( in.length == 2 )
			{
				tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
				tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
			}

			currBuiltinOp = new DataGenOp(OpOpDG.SAMPLE, target, tmpparams);
			break;
		}

		case SOLVE:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.SOLVE, expr, expr2);
			break;

		case INVERSE:
		case SQRT_MATRIX_JAVA:
		case CHOLESKY:
		case TYPEOF:
		case DET:
		case DETECTSCHEMA:
		case COLNAMES:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr);
			break;

		case OUTER:
			// the operator (third argument) must be a literal so it can be resolved here
			if( !(expr3 instanceof LiteralOp) )
				throw new HopsException("Operator for outer builtin function must be a constant: "+expr3);
			OpOp2 op = OpOp2.valueOfByOpcode(((LiteralOp)expr3).getStringValue());
			if( op == null )
				throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue());

			currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), op, expr, expr2, true);
			currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
			break;

		case BIASADD:
		case BIASMULT: {
			ArrayList<Hop> inHops1 = new ArrayList<>();
			inHops1.add(expr);
			inHops1.add(expr2);
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), inHops1);
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case AVG_POOL:
		case MAX_POOL: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForPoolingForwardIM2COL(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case AVG_POOL_BACKWARD:
		case MAX_POOL_BACKWARD: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOpPoolingCOL2IM(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case CONV2D:
		case CONV2D_BACKWARD_FILTER:
		case CONV2D_BACKWARD_DATA: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOp(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}

		case ROW_COUNT_DISTINCT:
			currBuiltinOp = new AggUnaryOp(target.getName(),
				DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Row, expr);
			break;

		case COL_COUNT_DISTINCT:
			currBuiltinOp = new AggUnaryOp(target.getName(),
				DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr);
			break;

		default:
			throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
		}

		boolean isConvolution = source.getOpCode() == Builtins.CONV2D || source.getOpCode() == Builtins.CONV2D_BACKWARD_DATA ||
			source.getOpCode() == Builtins.CONV2D_BACKWARD_FILTER || 
			source.getOpCode() == Builtins.MAX_POOL || source.getOpCode() == Builtins.MAX_POOL_BACKWARD || 
			source.getOpCode() == Builtins.AVG_POOL || source.getOpCode() == Builtins.AVG_POOL_BACKWARD;
		if( !isConvolution) {
			// Skipped for conv2d/pooling, since the dimensions of the output do not match
			// those of the input variable for these operations
			setIdentifierParams(currBuiltinOp, source.getOutput());
		}
		currBuiltinOp.setParseInfo(source);
		return currBuiltinOp;
	}