in src/main/java/org/apache/sysds/parser/DMLTranslator.java [2298:2831]
/**
 * Constructs the HOP (sub)DAG for a call to a built-in function.
 *
 * <p>The first, second and third argument expressions (when present) are translated eagerly via
 * {@code processExpression}. Builtins that take further arguments (e.g., eval/list, n-ary
 * min/max/cbind/rbind, table, sample) translate the remaining arguments themselves. Afterwards
 * the output identifier's parameters are applied to the resulting hop (except for
 * convolution/pooling ops), and the parse info of {@code source} is attached.
 *
 * @param source the built-in function call expression to translate
 * @param target the assignment target; if {@code null}, a target is created via {@code createTarget(source)}
 * @param hops   hop map passed through unchanged to nested expression translation
 *               (presumably keyed by variable name — confirm against callers)
 * @return the root hop constructed for this builtin call
 * @throws ParseException for unsupported builtins, an invalid number of table() arguments,
 *                        or an unknown PPRED comparison operator
 * @throws HopsException  for an invalid third argument type in sample() or an invalid/unsupported
 *                        operator in outer()
 */
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target,
	HashMap<String, Hop> hops) {
	// Eagerly translate up to three positional arguments; cases needing more
	// arguments translate them individually below.
	Hop expr = null;
	if(source.getFirstExpr() != null){
		expr = processExpression(source.getFirstExpr(), null, hops);
	}
	Hop expr2 = null;
	if (source.getSecondExpr() != null) {
		expr2 = processExpression(source.getSecondExpr(), null, hops);
	}
	Hop expr3 = null;
	if (source.getThirdExpr() != null) {
		expr3 = processExpression(source.getThirdExpr(), null, hops);
	}
	Hop currBuiltinOp = null;
	target = (target == null) ? createTarget(source) : target;
	// Construct the hop based on the type of Builtin function
	switch (source.getOpCode()) {
		case EVAL:
		case EVALLIST:
			currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
			break;
		case COLSUM:
		case COLMAX:
		case COLMIN:
		case COLMEAN:
		case COLPROD:
		case COLVAR:
			// strip the "COL" prefix to obtain the aggregate op name (e.g., COLSUM -> SUM)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
				AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Col, expr);
			break;
		case COLSD:
			// colStdDevs = sqrt(colVariances)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
				target.getValueType(), AggOp.VAR, Direction.Col, expr);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
				target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;
		case ROWSUM:
		case ROWMIN:
		case ROWMAX:
		case ROWMEAN:
		case ROWPROD:
		case ROWVAR:
			// strip the "ROW" prefix to obtain the aggregate op name (e.g., ROWSUM -> SUM)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
				AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Row, expr);
			break;
		case ROWINDEXMAX:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MAXINDEX,
				Direction.Row, expr);
			break;
		case ROWINDEXMIN:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MININDEX,
				Direction.Row, expr);
			break;
		case ROWSD:
			// rowStdDevs = sqrt(rowVariances)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
				target.getValueType(), AggOp.VAR, Direction.Row, expr);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
				target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;
		case NROW:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim1()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.NROW, expr) : new LiteralOp(expr.getDim1());
			break;
		case NCOL:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.NCOL, expr) : new LiteralOp(expr.getDim2());
			break;
		case LENGTH:
			// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			// Else create a UnaryOp so that a control program instruction is generated
			currBuiltinOp = (expr.getDim1()==-1 || expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.LENGTH, expr) : new LiteralOp(expr.getDim1()*expr.getDim2());
			break;
		case LINEAGE:
			//construct hop and enable lineage tracing if necessary
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.LINEAGE, expr);
			DMLScript.LINEAGE = true;
			break;
		case LIST:
			currBuiltinOp = new NaryOp(target.getName(), DataType.LIST, ValueType.UNKNOWN,
				OpOpN.LIST, processAllExpressions(source.getAllExpr(), hops));
			break;
		case EXISTS:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), OpOp1.EXISTS, expr);
			break;
		case SUM:
		case PROD:
		case VAR:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
				AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
			break;
		case MEAN:
			if ( expr2 == null ) {
				// example: x = mean(Y);
				currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.MEAN,
					Direction.RowCol, expr);
			}
			else {
				// example: x = mean(Y,W);
				// stable weighted mean is implemented by using centralMoment with order = 0
				Hop orderHop = new LiteralOp(0);
				currBuiltinOp=new TernaryOp(target.getName(), DataType.SCALAR,
					target.getValueType(), OpOp3.MOMENT, expr, expr2, orderHop);
			}
			break;
		case SD:
			// stdDev = sqrt(variance)
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
			HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
				target.getValueType(), OpOp1.SQRT, currBuiltinOp);
			break;
		case MIN:
		case MAX:
			//construct AggUnary for min(X) but BinaryOp for min(X,Y) and NaryOp for min(X,Y,Z)
			currBuiltinOp = (expr2 == null) ?
				new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
					AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr) :
				(source.getAllExpr().length == 2) ?
					new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
						OpOp2.valueOf(source.getOpCode().name()), expr, expr2) :
					new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
						OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops));
			break;
		case PPRED:
			// the third argument is the comparison operator given as a string literal
			String sop = ((StringIdentifier)source.getThirdExpr()).getValue();
			sop = sop.replace("\"", "");
			OpOp2 operation;
			if ( sop.equalsIgnoreCase(Opcodes.GREATEREQUAL.toString()) )
				operation = OpOp2.GREATEREQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.GREATER.toString()) )
				operation = OpOp2.GREATER;
			else if ( sop.equalsIgnoreCase(Opcodes.LESSEQUAL.toString()) )
				operation = OpOp2.LESSEQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.LESS.toString()) )
				operation = OpOp2.LESS;
			else if ( sop.equalsIgnoreCase(Opcodes.EQUAL.toString()) )
				operation = OpOp2.EQUAL;
			else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
				operation = OpOp2.NOTEQUAL;
			else {
				throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
			}
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
			break;
		case TRACE:
			currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.TRACE,
				Direction.RowCol, expr);
			break;
		case TRANS:
		case DIAG:
		case REV:
			currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
				target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr);
			break;
		case ROLL:
			// roll takes two inputs, so the list-based ReorgOp constructor is used
			ArrayList<Hop> inputs = new ArrayList<>();
			inputs.add(expr);
			inputs.add(expr2);
			currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX,
				target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs);
			break;
		case CBIND:
		case RBIND:
			// two inputs -> BinaryOp; more inputs -> NaryOp
			OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
			OpOpN appendOpN = (source.getOpCode()==Builtins.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
			currBuiltinOp = (source.getAllExpr().length == 2) ?
				new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp2, expr, expr2) :
				new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOpN,
					processAllExpressions(source.getAllExpr(), hops));
			break;
		case TABLE:
			// Always a TertiaryOp is created for table().
			// - create a hop for weights, if not provided in the function call.
			int numTableArgs = source._args.length;
			switch(numTableArgs) {
				case 2:
				case 4:
					// example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
					// here, weight is interpreted as 1.0
					Hop weightHop = new LiteralOp(1.0);
					// set dimensions
					weightHop.setDim1(0);
					weightHop.setDim2(0);
					weightHop.setNnz(-1);
					weightHop.setBlocksize(0);
					if ( numTableArgs == 2 )
						currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
					else {
						Hop outDim1 = processExpression(source._args[2], null, hops);
						Hop outDim2 = processExpression(source._args[3], null, hops);
						currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
							OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2, new LiteralOp(true));
					}
					break;
				case 3:
				case 5:
				case 6:
					// example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
					if (numTableArgs == 3)
						currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
					else {
						Hop outDim1 = processExpression(source._args[3], null, hops);
						Hop outDim2 = processExpression(source._args[4], null, hops);
						// optional sixth argument; defaults to literal true
						Hop outputEmptyBlocks = numTableArgs == 6 ?
							processExpression(source._args[5], null, hops) : new LiteralOp(true);
						currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(),
							OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2, outputEmptyBlocks);
					}
					break;
				default:
					throw new ParseException("Invalid number of arguments "+ numTableArgs + " to table() function.");
			}
			break;
		//data type casts
		case CAST_AS_SCALAR:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), OpOp1.CAST_AS_SCALAR, expr);
			break;
		case CAST_AS_MATRIX:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp1.CAST_AS_MATRIX, expr);
			break;
		case CAST_AS_FRAME:
			if(expr2 != null)
				currBuiltinOp = new BinaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp2.CAST_AS_FRAME, expr, expr2);
			else
				currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp1.CAST_AS_FRAME, expr);
			break;
		case CAST_AS_LIST:
			currBuiltinOp = new UnaryOp(target.getName(), DataType.LIST, target.getValueType(), OpOp1.CAST_AS_LIST, expr);
			break;
		//value type casts
		case CAST_AS_DOUBLE:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.CAST_AS_DOUBLE, expr);
			break;
		case CAST_AS_INT:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT64, OpOp1.CAST_AS_INT, expr);
			break;
		case CAST_AS_BOOLEAN:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, OpOp1.CAST_AS_BOOLEAN, expr);
			break;
		case LOCAL:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.LOCAL, expr);
			break;
		case COMPRESS:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.COMPRESS, expr);
			break;
		case DECOMPRESS:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
			break;
		// Boolean binary
		case XOR:
		case BITWAND:
		case BITWOR:
		case BITWXOR:
		case BITWSHIFTL:
		case BITWSHIFTR:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;
		case ABS:
		case SIN:
		case COS:
		case TAN:
		case ASIN:
		case ACOS:
		case ATAN:
		case SINH:
		case COSH:
		case TANH:
		case SIGN:
		case SQRT:
		case EXP:
		case ROUND:
		case CEIL:
		case FLOOR:
		case CUMSUM:
		case CUMPROD:
		case CUMSUMPROD:
		case CUMMIN:
		case CUMMAX:
		case ISNA:
		case ISNAN:
		case ISINF:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOp1.valueOf(source.getOpCode().name()), expr);
			break;
		case DROP_INVALID_TYPE:
		case DROP_INVALID_LENGTH:
		case VALUE_SWAP:
		case FRAME_ROW_REPLICATE:
		case APPLY_SCHEMA:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;
		case MAP:
			// optional third argument defaults to literal 0
			currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.valueOf(source.getOpCode().name()),
				expr, expr2, (expr3==null) ? new LiteralOp(0L) : expr3);
			break;
		case LOG:
			// log(X) becomes a UnaryOp; log(X, Y) becomes a BinaryOp.
			// (The opcode is already known to be LOG here, so no further dispatch is needed.)
			if (expr2 == null) {
				currBuiltinOp = new UnaryOp(target.getName(),
					target.getDataType(), target.getValueType(), OpOp1.LOG, expr);
			} else {
				currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.LOG,
					expr, expr2);
			}
			break;
		case MOMENT:
		case COV:
		case QUANTILE:
		case INTERQUANTILE:
			// two arguments -> BinaryOp; three arguments -> TernaryOp
			currBuiltinOp = (expr3 == null) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOp2.valueOf(source.getOpCode().name()), expr, expr2) : new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.valueOf(source.getOpCode().name()), expr, expr2,expr3);
			break;
		case IQM:
		case MEDIAN:
			// one argument -> UnaryOp; two arguments -> BinaryOp
			currBuiltinOp = (expr2 == null) ? new UnaryOp(target.getName(), target.getDataType(), target.getValueType(),
				OpOp1.valueOf(source.getOpCode().name()), expr) : new BinaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
			break;
		case IFELSE:
			currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp3.IFELSE, expr, expr2, expr3);
			break;
		case SEQ:
			HashMap<String,Hop> randParams = new HashMap<>();
			randParams.put(Statement.SEQ_FROM, expr);
			randParams.put(Statement.SEQ_TO, expr2);
			randParams.put(Statement.SEQ_INCR, (expr3!=null)?expr3 : new LiteralOp(1));
			//note incr: default -1 (for from>to) handled during runtime
			currBuiltinOp = new DataGenOp(OpOpDG.SEQ, target, randParams);
			break;
		case TIME:
			currBuiltinOp = new DataGenOp(OpOpDG.TIME, target);
			break;
		case SAMPLE:
		{
			Expression[] in = source.getAllExpr();

			// arguments: range/size/replace/seed; defaults: replace=FALSE

			HashMap<String,Hop> tmpparams = new HashMap<>();
			tmpparams.put(DataExpression.RAND_MAX, expr); //range
			tmpparams.put(DataExpression.RAND_ROWS, expr2);
			tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));

			if ( in.length == 4 )
			{
				tmpparams.put(DataExpression.RAND_PDF, expr3);
				Hop seed = processExpression(in[3], null, hops);
				tmpparams.put(DataExpression.RAND_SEED, seed);
			}
			else if ( in.length == 3 )
			{
				// check if the third argument is "replace" or "seed"
				if ( expr3.getValueType() == ValueType.BOOLEAN )
				{
					tmpparams.put(DataExpression.RAND_PDF, expr3);
					tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
				}
				else if ( expr3.getValueType() == ValueType.INT64 )
				{
					tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
					tmpparams.put(DataExpression.RAND_SEED, expr3 );
				}
				else
					throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");

			}
			else if ( in.length == 2 )
			{
				tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
				tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) );
			}

			currBuiltinOp = new DataGenOp(OpOpDG.SAMPLE, target, tmpparams);
			break;
		}
		case SOLVE:
			currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.SOLVE, expr, expr2);
			break;
		case INVERSE:
		case SQRT_MATRIX_JAVA:
		case CHOLESKY:
		case TYPEOF:
		case DET:
		case DETECTSCHEMA:
		case COLNAMES:
			currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
				target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr);
			break;
		case OUTER:
			// the third argument must be a literal naming the binary operator (resolved by opcode)
			if( !(expr3 instanceof LiteralOp) )
				throw new HopsException("Operator for outer builtin function must be a constant: "+expr3);
			OpOp2 op = OpOp2.valueOfByOpcode(((LiteralOp)expr3).getStringValue());
			if( op == null )
				throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue());
			currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), op, expr, expr2, true);
			currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
			break;
		case BIASADD:
		case BIASMULT: {
			ArrayList<Hop> inHops1 = new ArrayList<>();
			inHops1.add(expr);
			inHops1.add(expr2);
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), inHops1);
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case AVG_POOL:
		case MAX_POOL: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForPoolingForwardIM2COL(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case AVG_POOL_BACKWARD:
		case MAX_POOL_BACKWARD: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOpPoolingCOL2IM(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case CONV2D:
		case CONV2D_BACKWARD_FILTER:
		case CONV2D_BACKWARD_DATA: {
			currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(),
				OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOp(expr, source, 1, hops));
			setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
			break;
		}
		case ROW_COUNT_DISTINCT:
			currBuiltinOp = new AggUnaryOp(target.getName(),
				DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Row, expr);
			break;

		case COL_COUNT_DISTINCT:
			currBuiltinOp = new AggUnaryOp(target.getName(),
				DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr);
			break;

		default:
			throw new ParseException("Unsupported builtin function type: "+source.getOpCode());
	}

	boolean isConvolution = source.getOpCode() == Builtins.CONV2D || source.getOpCode() == Builtins.CONV2D_BACKWARD_DATA ||
		source.getOpCode() == Builtins.CONV2D_BACKWARD_FILTER ||
		source.getOpCode() == Builtins.MAX_POOL || source.getOpCode() == Builtins.MAX_POOL_BACKWARD ||
		source.getOpCode() == Builtins.AVG_POOL || source.getOpCode() == Builtins.AVG_POOL_BACKWARD;
	if( !isConvolution) {
		// Convolution/pooling ops are skipped here: their output dimensions do not match
		// those of the input variable, and their sizes were already set above via
		// setBlockSizeAndRefreshSizeInfo.
		setIdentifierParams(currBuiltinOp, source.getOutput());
	}
	currBuiltinOp.setParseInfo(source);
	return currBuiltinOp;
}