in src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java [877:2066]
public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
{
for(int i=0; i < _args.length; i++ ) {
if (_args[i] instanceof FunctionCallIdentifier){
raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
}
_args[i].validateExpression(ids, constVars, conditional);
}
// checkIdentifierParams();
String outputName = getTempName();
DataIdentifier output = new DataIdentifier(outputName);
output.setParseInfo(this);
if (getFirstExpr() == null && !isValidNoArgumentFunction()) { // time has no arguments
raiseValidateError("Function " + this + " has no arguments.", false);
}
Identifier id = (_args.length != 0) ?
getFirstExpr().getOutput() : null;
if (_args.length != 0)
output.setProperties(this.getFirstExpr().getOutput());
output.setNnz(-1); //conservatively, cannot use input nnz!
setOutput(output);
switch (getOpCode()) {
case EVAL:
case EVALLIST:
if (_args.length == 0)
raiseValidateError("Function eval should provide at least one argument, i.e., the function name.", false);
checkValueTypeParam(_args[0], ValueType.STRING);
boolean listReturn = (getOpCode()==Builtins.EVALLIST);
output.setDataType(listReturn ? DataType.LIST : DataType.MATRIX);
output.setValueType(listReturn ? ValueType.UNKNOWN : ValueType.FP64);
output.setDimensions(-1, -1);
output.setBlocksize(ConfigurationManager.getBlocksize());
break;
case COLSUM:
case COLMAX:
case COLMIN:
case COLMEAN:
case COLPROD:
case COLSD:
case COLVAR:
// colSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(1, id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case ROWSUM:
case ROWMAX:
case ROWINDEXMAX:
case ROWMIN:
case ROWINDEXMIN:
case ROWMEAN:
case ROWPROD:
case ROWSD:
case ROWVAR:
//rowSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), 1);
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case TRACE:
if(getFirstExpr().getOutput().dimsKnown()
&& getFirstExpr().getOutput().getDim1() != getFirstExpr().getOutput().getDim2())
{
raiseValidateError("Trace is only defined on squared matrices but found ["
+getFirstExpr().getOutput().getDim1()+"x"+getFirstExpr().getOutput().getDim2()+"].", conditional);
}
case SUM:
case PROD:
case SD:
case VAR:
// sum(X);
checkNumParameters(1);
checkMatrixTensorParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
switch (id.getValueType()) {
case INT64:
case INT32:
case UINT8:
case UINT4:
case BOOLEAN:
output.setValueType(ValueType.INT64);
break;
case STRING:
case CHARACTER:
case FP64:
case FP32:
case HASH32:
case HASH64: //default
output.setValueType(ValueType.FP64);
break;
case UNKNOWN:
throw new NotImplementedException();
}
break;
case MEAN:
//checkNumParameters(2, false); // mean(Y) or mean(Y,W)
if (getSecondExpr() != null) {
checkNumParameters(2);
}
else {
checkNumParameters(1);
}
checkMatrixParam(getFirstExpr());
if ( getSecondExpr() != null ) {
// x = mean(Y,W);
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(id.getValueType());
break;
case XOR:
case BITWAND:
case BITWOR:
case BITWXOR:
case BITWSHIFTL:
case BITWSHIFTR:
checkNumParameters(2);
setBinaryOutputProperties(output);
break;
case MIN:
case MAX:
//min(X), min(X,s), min(s,X), min(s,r), min(X,Y)
if (getSecondExpr() == null) { //unary
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
output.setValueType(id.getValueType());
output.setDimensions(0, 0);
output.setBlocksize(0);
}
else if( getAllExpr().length == 2 ) { //binary
checkNumParameters(2);
setBinaryOutputProperties(output);
}
else { //nary
for( Expression e : getAllExpr() )
checkMatrixScalarParam(e);
setNaryOutputProperties(output);
}
break;
case CUMSUM:
case CUMPROD:
case CUMSUMPROD:
case CUMMIN:
case CUMMAX:
// cumsum(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
boolean cumSP = getOpCode() == Builtins.CUMSUMPROD;
if( cumSP && id.getDim2() > 2 )
raiseValidateError("Cumsumprod only supported over two-column matrices", conditional);
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), cumSP ? 1 : id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case CAST_AS_SCALAR:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
DataType.MATRIX, DataType.FRAME, DataType.LIST);
if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1)
|| ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1()
+ " dim2 " + getFirstExpr().getOutput().getDim2(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType((id.getValueType()!=ValueType.UNKNOWN
|| id.getDataType()==DataType.LIST) ? id.getValueType() : ValueType.FP64);
break;
case CAST_AS_MATRIX:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
DataType.SCALAR, DataType.FRAME, DataType.LIST);
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
if( getFirstExpr().getOutput().getDataType()==DataType.SCALAR )
output.setDimensions(1, 1); //correction scalars
if( getFirstExpr().getOutput().getDataType()==DataType.LIST )
output.setDimensions(-1, -1); //correction list: arbitrary object
output.setBlocksize(id.getBlocksize());
output.setValueType(ValueType.FP64); //matrices always in double
break;
case CAST_AS_LIST: //list unnesting
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(), DataType.LIST);
output.setDataType(DataType.LIST);
output.setDimensions(-1, 1);
output.setBlocksize(id.getBlocksize());
output.setValueType(ValueType.UNKNOWN);
break;
case TYPEOF:
case DETECTSCHEMA:
case COLNAMES:
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(DataType.FRAME);
output.setDimensions(1, id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(ValueType.STRING);
break;
case CAST_AS_FRAME:
// operation as.frame
// overloaded to take either one argument or 2 where second is column names
if( getSecondExpr() == null) {// there is no column names
checkNumParameters(1);
}
else{ // there is column names
checkNumParameters(2);
checkDataTypeParam(getSecondExpr(), DataType.LIST);
}
checkDataTypeParam(getFirstExpr(), DataType.SCALAR, DataType.MATRIX, DataType.LIST);
output.setDataType(DataType.FRAME);
output.setDimensions(id.getDim1(), id.getDim2());
if(getFirstExpr().getOutput().getDataType() == DataType.SCALAR)
output.setDimensions(1, 1); // correction scalars
if(getFirstExpr().getOutput().getDataType() == DataType.LIST)
output.setDimensions(-1, -1); // correction list: arbitrary object
output.setBlocksize(id.getBlocksize());
output.setValueType(id.getValueType());
break;
case CAST_AS_DOUBLE:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.FP64);
break;
case CAST_AS_INT:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;
case CAST_AS_BOOLEAN:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.BOOLEAN);
break;
case IFELSE:
checkNumParameters(3);
setTernaryOutputProperties(output, conditional);
break;
case CBIND:
case RBIND:
//scalar string append (string concatenation with \n)
if( getFirstExpr().getOutput().getDataType()==DataType.SCALAR ) {
checkNumParameters(2);
checkScalarParam(getFirstExpr());
checkScalarParam(getSecondExpr());
checkValueTypeParam(getFirstExpr(), ValueType.STRING);
checkValueTypeParam(getSecondExpr(), ValueType.STRING);
}
// append (rbind/cbind) all the elements of a list
else if( getAllExpr().length == 1 ) {
checkDataTypeParam(getFirstExpr(), DataType.LIST);
}
else {
if( getAllExpr().length < 2 )
raiseValidateError("Invalid number of arguments for "+getOpCode(), conditional);
//list append
if(getFirstExpr().getOutput().getDataType().isList() )
for(int i=1; i<getAllExpr().length; i++)
checkDataTypeParam(getExpr(i), DataType.SCALAR, DataType.MATRIX, DataType.FRAME, DataType.LIST);
//matrix append (rbind/cbind)
else
for(int i=0; i<getAllExpr().length; i++)
checkMatrixFrameParam(getExpr(i));
}
output.setDataType(id.getDataType());
output.setValueType(id.getValueType());
//special handling of concatenating all list elements
if( id.getDataType() == DataType.LIST && getAllExpr().length == 1) {
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
}
// set output dimensions and validate consistency
long m1rlen = getFirstExpr().getOutput().getDim1();
long m1clen = getFirstExpr().getOutput().getDim2();
long appendDim1 = m1rlen, appendDim2 = m1clen;
// best-effort dimension propagation and validation
if( id.getDataType() == DataType.LIST ) {
appendDim1 = -1;
appendDim2 = -1;
}
else {
for(int i=1; i<getAllExpr().length; i++) {
long m2rlen = getExpr(i).getOutput().getDim1();
long m2clen = getExpr(i).getOutput().getDim2();
if( getOpCode() == Builtins.CBIND ) {
if (m1rlen >= 0 && m2rlen >= 0 && m1rlen!=m2rlen) {
raiseValidateError("inputs to cbind must have same number of rows: input 1 rows: " +
m1rlen+", input 2 rows: "+m2rlen, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
appendDim1 = (m2rlen>=0) ? m2rlen : appendDim1;
appendDim2 = (appendDim2>=0 && m2clen>=0) ? appendDim2 + m2clen : -1;
}
else if( getOpCode() == Builtins.RBIND ) {
if (m1clen >= 0 && m2clen >= 0 && m1clen!=m2clen) {
raiseValidateError("inputs to rbind must have same number of columns: input 1 columns: " +
m1clen+", input 2 columns: "+m2clen, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
appendDim1 = (appendDim1>=0 && m2rlen>=0)? appendDim1 + m2rlen : -1;
appendDim2 = (m2clen>=0) ? m2clen : appendDim2;
}
}
}
output.setDimensions(appendDim1, appendDim2);
output.setBlocksize (id.getBlocksize());
break;
case PPRED:
// TODO: remove this when ppred has been removed from DML
raiseValidateError("ppred() has been deprecated. Please use the operator directly.", true);
// ppred (X,Y, "<"); ppred (X,y, "<"); ppred (y,X, "<");
checkNumParameters(3);
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
//check input data types
if( dt1 == DataType.SCALAR && dt2 == DataType.SCALAR ) {
raiseValidateError("ppred() requires at least one matrix input.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
if( dt1 == DataType.MATRIX )
checkMatrixParam(getFirstExpr());
if( dt2 == DataType.MATRIX )
checkMatrixParam(getSecondExpr());
//check operator
if (getThirdExpr().getOutput().getDataType() != DataType.SCALAR ||
getThirdExpr().getOutput().getValueType() != ValueType.STRING)
{
raiseValidateError("Third argument in ppred() is not an operator ", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
setBinaryOutputProperties(output);
break;
case TRANS:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim2(), id.getDim1());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case REV:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case ROLL:
checkNumParameters(2);
checkMatrixParam(getFirstExpr());
checkScalarParam(getSecondExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize(id.getBlocksize());
output.setValueType(id.getValueType());
break;
case DIAG:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
if( id.getDim2() != -1 ) { //type known
if ( id.getDim2() == 1 )
{
//diag V2M
output.setDimensions(id.getDim1(), id.getDim1());
}
else
{
if (id.getDim1() != id.getDim2()) {
raiseValidateError("diag can either: (1) create diagonal matrix from (n x 1) matrix, or (2) take diagonal from a square matrix. "
+ "Error invoking diag on matrix with dimensions ("
+ id.getDim1() + "," + id.getDim2()
+ ") in " + this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
//diag M2V
output.setDimensions(id.getDim1(), 1);
}
}
output.setBlocksize(id.getBlocksize());
output.setValueType(id.getValueType());
break;
case DET:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
if ( id.getDim2() == -1 || id.getDim1() != id.getDim2() ) {
raiseValidateError("det requires a square matrix as first argument.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.FP64);
break;
case NROW:
case NCOL:
case LENGTH:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
DataType.FRAME, DataType.LIST, DataType.MATRIX);
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;
case LINEAGE:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
DataType.MATRIX, DataType.FRAME, DataType.LIST);
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.STRING);
break;
case LIST:
output.setDataType(DataType.LIST);
output.setValueType(ValueType.UNKNOWN);
output.setDimensions(getAllExpr().length, 1);
output.setBlocksize(-1);
break;
case EXISTS:
checkNumParameters(1);
checkStringOrDataIdentifier(getFirstExpr());
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.BOOLEAN);
break;
// Contingency tables
case TABLE:
/*
* Allowed #of arguments: 2,3,4,5,6
* table(A,B)
* table(A,B,W)
* table(A,B,1)
* table(A,B,dim1,dim2)
* table(A,B,W,dim1,dim2)
* table(A,B,1,dim1,dim2)
* table(A,B,1,dim1,dim2,TRUE)
*/
// Check for validity of input arguments, and setup output dimensions
// First input: is always of type MATRIX
checkMatrixParam(getFirstExpr());
if (getSecondExpr() == null)
raiseValidateError("Invalid number of arguments to table(). "
+ "The table() function requires 2, 3, 4, 5, or 6 arguments.", conditional);
// Second input: can be MATRIX or SCALAR
// cases: table(A,B) or table(A,1)
if ( getSecondExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
long outputDim1=-1, outputDim2=-1;
switch(_args.length) {
case 2:
// nothing to do
break;
case 3:
// case - table w/ weights
// - weights specified as a matrix: table(A,B,W) or table(A,1,W)
// - weights specified as a scalar: table(A,B,1) or table(A,1,1)
if ( getThirdExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getThirdExpr());
break;
case 4:
// case - table w/ output dimensions: table(A,B,dim1,dim2) or table(A,1,dim1,dim2)
// third and fourth arguments must be scalars
if ( getThirdExpr().getOutput().getDataType() != DataType.SCALAR || _args[3].getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Invalid argument types to table(): output dimensions must be of type scalar: "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else {
// constant propagation
if( getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) && !conditional )
_args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName());
if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) && !conditional )
_args[3] = constVars.get(((DataIdentifier)_args[3]).getName());
if ( getThirdExpr().getOutput() instanceof ConstIdentifier )
outputDim1 = ((ConstIdentifier) getThirdExpr().getOutput()).getLongValue();
if ( _args[3].getOutput() instanceof ConstIdentifier )
outputDim2 = ((ConstIdentifier) _args[3].getOutput()).getLongValue();
}
break;
case 5:
case 6:
// case - table w/ weights and output dimensions:
// - table(A,B,W,dim1,dim2) or table(A,1,W,dim1,dim2)
// - table(A,B,1,dim1,dim2) or table(A,1,1,dim1,dim2)
if ( getThirdExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getThirdExpr());
// fourth and fifth arguments must be scalars
if ( _args[3].getOutput().getDataType() != DataType.SCALAR || _args[4].getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Invalid argument types to table(): output dimensions must be of type scalar: "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else {
// constant propagation
if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) && !conditional )
_args[3] = constVars.get(((DataIdentifier)_args[3]).getName());
if( _args[4] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[4]).getName()) && !conditional )
_args[4] = constVars.get(((DataIdentifier)_args[4]).getName());
if ( _args[3].getOutput() instanceof ConstIdentifier )
outputDim1 = ((ConstIdentifier) _args[3].getOutput()).getLongValue();
if ( _args[4].getOutput() instanceof ConstIdentifier )
outputDim2 = ((ConstIdentifier) _args[4].getOutput()).getLongValue();
}
if( _args.length == 6 ) {
if( !_args[5].getOutput().isScalarBoolean() )
raiseValidateError("The 6th ctable parameter (outputEmptyBlocks) must be a boolean literal.", conditional);
}
break;
default:
raiseValidateError("Invalid number of arguments to table(): "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
// The dimensions for the output matrix will be known only at the
// run time
output.setDimensions(outputDim1, outputDim2);
output.setBlocksize (-1);
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
break;
case MOMENT:
checkMatrixParam(getFirstExpr());
if (getThirdExpr() != null) {
checkNumParameters(3);
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
checkScalarParam(getThirdExpr());
}
else {
checkNumParameters(2);
checkScalarParam(getSecondExpr());
}
// output is a scalar
output.setDataType(DataType.SCALAR);
output.setValueType(ValueType.FP64);
output.setDimensions(0, 0);
output.setBlocksize(0);
break;
case COV:
/*
* x = cov(V1,V2) or xw = cov(V1,V2,W)
*/
if (getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
if (getThirdExpr() != null) {
checkMatrixParam(getThirdExpr());
checkMatchingDimensions(getFirstExpr(), getThirdExpr());
}
// output is a scalar
output.setDataType(DataType.SCALAR);
output.setValueType(ValueType.FP64);
output.setDimensions(0, 0);
output.setBlocksize(0);
break;
case QUANTILE:
/*
* q = quantile(V1,0.5) computes median in V1
* or Q = quantile(V1,P) computes the vector of quantiles as specified by P
* or qw = quantile(V1,W,0.5) computes median when weights (W) are given
* or QW = quantile(V1,W,P) computes the vector of quantiles as specified by P, when weights (W) are given
*/
if(getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
// first parameter must always be a 1D matrix
check1DMatrixParam(getFirstExpr());
// check for matching dimensions for other matrix parameters
if (getThirdExpr() != null) {
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// set the properties for _output expression
// output dimensions = dimensions of second, if third is null
// = dimensions of the third, otherwise.
if (getThirdExpr() != null) {
output.setDimensions(getThirdExpr().getOutput().getDim1(), getThirdExpr().getOutput().getDim2());
output.setBlocksize(getThirdExpr().getOutput().getBlocksize());
output.setDataType(getThirdExpr().getOutput().getDataType());
} else {
output.setDimensions(getSecondExpr().getOutput().getDim1(), getSecondExpr().getOutput().getDim2());
output.setBlocksize(getSecondExpr().getOutput().getBlocksize());
output.setDataType(getSecondExpr().getOutput().getDataType());
}
break;
case INTERQUANTILE:
if (getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
checkMatrixParam(getFirstExpr());
if (getThirdExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensionsQuantile();
}
if ((getThirdExpr() == null && getSecondExpr().getOutput().getDataType() != DataType.SCALAR)
&& (getThirdExpr() != null && getThirdExpr().getOutput().getDataType() != DataType.SCALAR)) {
raiseValidateError("Invalid parameters to "+ this.getOpCode(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setValueType(id.getValueType());
// output dimensions are unknown
output.setDimensions(-1, -1);
output.setBlocksize(-1);
output.setDataType(DataType.MATRIX);
break;
case IQM:
/*
* Usage: iqm = InterQuartileMean(A,W); iqm = InterQuartileMean(A);
*/
if (getSecondExpr() != null){
checkNumParameters(2);
}
else {
checkNumParameters(1);
}
checkMatrixParam(getFirstExpr());
if (getSecondExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// Output is a scalar
output.setValueType(id.getValueType());
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setDataType(DataType.SCALAR);
break;
case ISNA:
case ISNAN:
case ISINF:
checkNumParameters(1);
checkMatrixScalarParam(getFirstExpr());
output.setDataType(id.getDataType());
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
//TODO set output type to boolean when supported
output.setValueType(id.getValueType());
break;
case MEDIAN:
checkNumParameters((getSecondExpr()!=null) ? 2 : 1);
checkMatrixParam(getFirstExpr());
if (getSecondExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// Output is a scalar
output.setValueType(id.getValueType());
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setDataType(DataType.SCALAR);
break;
case SAMPLE:
{
Expression[] in = getAllExpr();
for(Expression e : in)
checkScalarParam(e);
if (in[0].getOutput().getValueType() != ValueType.FP64 && in[0].getOutput().getValueType() != ValueType.INT64)
throw new LanguageException("First argument to sample() must be a number.");
if (in[1].getOutput().getValueType() != ValueType.FP64 && in[1].getOutput().getValueType() != ValueType.INT64)
throw new LanguageException("Second argument to sample() must be a number.");
boolean check = false;
if ( isConstant(in[0]) && isConstant(in[1]) )
{
long range = ((ConstIdentifier)in[0]).getLongValue();
long size = ((ConstIdentifier)in[1]).getLongValue();
if ( range < size )
check = true;
}
if(in.length == 4 )
{
checkNumParameters(4);
if (in[3].getOutput().getValueType() != ValueType.INT64)
throw new LanguageException("Fourth argument, seed, to sample() must be an integer value.");
if (in[2].getOutput().getValueType() != ValueType.BOOLEAN )
throw new LanguageException("Third argument to sample() must either denote replacement policy (boolean) or seed (integer).");
}
else if(in.length == 3)
{
checkNumParameters(3);
if (in[2].getOutput().getValueType() != ValueType.BOOLEAN
&& in[2].getOutput().getValueType() != ValueType.INT64 )
throw new LanguageException("Third argument to sample() must either denote replacement policy (boolean) or seed (integer).");
}
if ( check && in.length >= 3
&& isConstant(in[2])
&& in[2].getOutput().getValueType() == ValueType.BOOLEAN
&& !((BooleanIdentifier)in[2]).getValue() )
throw new LanguageException("Sample (size=" + ((ConstIdentifier)in[0]).getLongValue()
+ ") larger than population (size=" + ((ConstIdentifier)in[1]).getLongValue()
+ ") can only be generated with replacement.");
// Output is a column vector
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
if ( isConstant(in[1]) )
output.setDimensions(((ConstIdentifier)in[1]).getLongValue(), 1);
else
output.setDimensions(-1, 1);
setBlocksize(id.getBlocksize());
break;
}
case SEQ:
//basic parameter validation
checkScalarParam(getFirstExpr());
checkScalarParam(getSecondExpr());
if ( getThirdExpr() != null ) {
checkNumParameters(3);
checkScalarParam(getThirdExpr());
}
else
checkNumParameters(2);
// constant propagation (from, to, incr)
if( !conditional ) {
if( getFirstExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getFirstExpr()).getName()) )
_args[0] = constVars.get(((DataIdentifier)getFirstExpr()).getName());
if( getSecondExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getSecondExpr()).getName()) )
_args[1] = constVars.get(((DataIdentifier)getSecondExpr()).getName());
if( getThirdExpr()!=null && getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) )
_args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName());
}
// check if dimensions can be inferred
long dim1=-1, dim2=1;
if ( isConstant(getFirstExpr()) && isConstant(getSecondExpr()) && (getThirdExpr() != null ? isConstant(getThirdExpr()) : true) ) {
double from, to, incr;
try {
from = getDoubleValue(getFirstExpr());
to = getDoubleValue(getSecondExpr());
// Setup the value of increment
// default value: 1 if from <= to; -1 if from > to
if(getThirdExpr() == null) {
expandArguments();
_args[2] = new DoubleIdentifier(((from > to) ? -1.0 : 1.0), this);
}
incr = getDoubleValue(getThirdExpr());
}
catch (LanguageException e) {
throw new LanguageException("Arguments for seq() must be numeric.");
}
if( (from > to) && (incr >= 0) )
throw new LanguageException("Wrong sign for the increment in a call to seq()");
// Both end points of the range must included i.e., [from,to] both inclusive.
// Note that, "to" is included only if (to-from) is perfectly divisible by incr
// For example, seq(0,1,0.5) produces (0.0 0.5 1.0) whereas seq(0,1,0.6) produces only (0.0 0.6) but not (0.0 0.6 1.0)
dim1 = UtilFunctions.getSeqLength(from, to, incr);
}
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
output.setDimensions(dim1, dim2);
output.setBlocksize(0);
break;
case SOLVE:
checkNumParameters(2);
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
if ( getSecondExpr().getOutput().dimsKnown() && !is1DMatrix(getSecondExpr()) )
raiseValidateError("Second input to solve() must be a vector", conditional);
if ( getFirstExpr().getOutput().dimsKnown() && getSecondExpr().getOutput().dimsKnown() &&
getFirstExpr().getOutput().getDim1() != getSecondExpr().getOutput().getDim1() &&
getFirstExpr().getOutput().getDim1() != getFirstExpr().getOutput().getDim2())
raiseValidateError("Dimension mismatch in a call to solve()", conditional);
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
output.setDimensions(getFirstExpr().getOutput().getDim2(), 1);
output.setBlocksize(0);
break;
case INVERSE:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
Identifier in = getFirstExpr().getOutput();
if(in.dimsKnown() && in.getDim1() != in.getDim2())
raiseValidateError("Input to inv() must be square matrix -- given: a " + in.getDim1() + "x" + in.getDim2() + " matrix.", conditional);
output.setDimensions(in.getDim1(), in.getDim2());
output.setBlocksize(in.getBlocksize());
break;
case SQRT_MATRIX_JAVA:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
Identifier sqrt = getFirstExpr().getOutput();
if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2())
raiseValidateError("Input to sqrtMatrix() must be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + " matrix.", conditional);
output.setDimensions( sqrt.getDim1(), sqrt.getDim2());
output.setBlocksize( sqrt.getBlocksize());
break;
case CHOLESKY:
{
// A = L%*%t(L) where L is the lower triangular matrix
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
Identifier inA = getFirstExpr().getOutput();
if(inA.dimsKnown() && inA.getDim1() != inA.getDim2())
raiseValidateError("Input to cholesky() must be square matrix -- given: a " + inA.getDim1() + "x" + inA.getDim2() + " matrix.", conditional);
output.setDimensions(inA.getDim1(), inA.getDim2());
output.setBlocksize(inA.getBlocksize());
break;
}
case OUTER:
Identifier id2 = this.getSecondExpr().getOutput();
//check input types and characteristics
checkNumParameters(3);
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
checkScalarParam(getThirdExpr());
checkValueTypeParam(getThirdExpr(), ValueType.STRING);
if( id.getDim2() > 1 || id2.getDim1()>1 ) {
raiseValidateError("Outer vector operations require a common dimension of one: " +
id.getDim1()+"x"+id.getDim2()+" o "+id2.getDim1()+"x"+id2.getDim2()+".", false);
}
//set output characteristics
output.setDataType(id.getDataType());
output.setDimensions(id.getDim1(), id2.getDim2());
output.setBlocksize(id.getBlocksize());
break;
case BIASADD:
case BIASMULT:
{
Expression input = _args[0];
Expression bias = _args[1];
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
output.setBlocksize(input.getOutput().getBlocksize());
checkMatrixParam(input);
checkMatrixParam(bias);
break;
}
case CONV2D:
case CONV2D_BACKWARD_FILTER:
case CONV2D_BACKWARD_DATA:
case MAX_POOL:
case AVG_POOL:
case MAX_POOL_BACKWARD:
case AVG_POOL_BACKWARD:
{
// At DML level:
// output = conv2d(input, filter, input_shape=[1, 3, 2, 2], filter_shape=[1, 3, 2, 2],
// strides=[1, 1], padding=[1,1])
//
// Converted to following in constructor (only supported NCHW):
// output = conv2d(input, filter, stride1, stride2, padding1,padding2,
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4)
//
// Similarly,
// conv2d_backward_filter and conv2d_backward_data
Expression input = _args[0]; // For conv2d_backward_filter, this is input and for conv2d_backward_data, this is filter
Expression input2 = null;
if(!(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AVG_POOL)) {
input2 = _args[1]; // For conv2d_backward functions, this is dout
checkMatrixParam(input2);
}
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
output.setBlocksize(input.getOutput().getBlocksize());
if(this.getOpCode() == Builtins.MAX_POOL_BACKWARD || this.getOpCode() == Builtins.AVG_POOL_BACKWARD) {
output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
}
else {
// stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize,
// filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1
try {
int start = 2;
if(!(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AVG_POOL)) {
start = 1;
}
long stride_h = (long) getDoubleValue(_args[start++]);
long stride_w = (long) getDoubleValue(_args[start++]);
long pad_h = (long) getDoubleValue(_args[start++]);
long pad_w = (long) getDoubleValue(_args[start++]);
long N = (long) getDoubleValue(_args[start++]);
long C = (long) getDoubleValue(_args[start++]);
long H = (long) getDoubleValue(_args[start++]);
long W = (long) getDoubleValue(_args[start++]);
long K = -1;
if(!(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AVG_POOL)) {
K = (long) getDoubleValue(_args[start]);
}
start++; start++; // Increment index for K and C
long R = (long) getDoubleValue(_args[start++]);
long S = (long) getDoubleValue(_args[start++]);
if(this.getOpCode() == Builtins.CONV2D_BACKWARD_FILTER) {
output.setDimensions(K, C*R*S);
}
else if(this.getOpCode() == Builtins.CONV2D_BACKWARD_DATA) {
output.setDimensions(N, C*H*W);
}
else if(H > 0 && W > 0 && stride_h > 0 && stride_w > 0 && pad_h >= 0 && pad_w >= 0 && R > 0 && S > 0) {
long P = DnnUtils.getP(H, R, stride_h, pad_h);
long Q = DnnUtils.getQ(W, S, stride_w, pad_w);
// Try to set both rows and columns
if(this.getOpCode() == Builtins.CONV2D)
output.setDimensions(N, K*P*Q);
else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AVG_POOL)
output.setDimensions(N, C*P*Q);
else
throw new LanguageException("");
}
else {
// Since columns cannot be computed, set only rows
if(this.getOpCode() == Builtins.CONV2D)
output.setDimensions(input.getOutput().getDim1(), -1);
else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AVG_POOL)
output.setDimensions(input.getOutput().getDim1(), -1);
else
throw new LanguageException("");
}
}
catch(Exception e) {
output.setDimensions(-1, -1); // To make sure that output dimensions are not incorrect even if getDoubleValue doesnot return value
}
}
checkMatrixParam(input);
if(input2 != null)
checkMatrixParam(input2);
break;
}
case TIME:
checkNumParameters(0);
// Output of TIME() is scalar and long
output.setDataType(DataType.SCALAR);
output.setValueType(ValueType.INT64);
output.setDimensions(0, 0);
output.setBlocksize(0);
break;
case DROP_INVALID_TYPE:
case VALUE_SWAP:
case FRAME_ROW_REPLICATE:
checkNumParameters(2);
checkMatrixFrameParam(getFirstExpr());
checkMatrixFrameParam(getSecondExpr());
output.setDataType(DataType.FRAME);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(ValueType.STRING);
break;
case DROP_INVALID_LENGTH:
checkNumParameters(2);
checkMatrixFrameParam(getFirstExpr());
checkMatrixFrameParam(getSecondExpr());
output.setDataType(DataType.FRAME);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
case APPLY_SCHEMA:
checkNumParameters(2);
checkMatrixFrameParam(getFirstExpr());
checkMatrixFrameParam(getSecondExpr());
output.setDataType(DataType.FRAME);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
break;
case MAP:
checkNumParameters(getThirdExpr() != null ? 3 : 2);
checkMatrixFrameParam(getFirstExpr());
checkScalarParam(getSecondExpr());
if(getThirdExpr() != null)
checkScalarParam(getThirdExpr()); // margin
output.setDataType(DataType.FRAME);
if(_args[1].getText().contains("jaccardSim")) {
output.setDimensions(id.getDim1(), id.getDim1());
output.setValueType(ValueType.FP64);
}
else {
output.setDimensions(id.getDim1(), id.getDim2());
output.setValueType(ValueType.STRING);
}
break;
case LOCAL:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_LOCAL_COMMAND){
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
}
else
raiseValidateError("Local instruction not allowed in dml script");
case COMPRESS:
case DECOMPRESS:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(getFirstExpr().getOutput().getDataType());
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
}
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
break;
case ROW_COUNT_DISTINCT:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), 1);
output.setBlocksize (id.getBlocksize());
output.setValueType(ValueType.INT64);
output.setNnz(id.getDim1());
break;
case COL_COUNT_DISTINCT:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(1, id.getDim2());
output.setBlocksize (id.getBlocksize());
output.setValueType(ValueType.INT64);
output.setNnz(id.getDim2());
break;
default:
if( isMathFunction() ) {
checkMathFunctionParam();
//unary operations
if( getSecondExpr() == null ) {
output.setDataType(id.getDataType());
output.setValueType((output.getDataType()==DataType.SCALAR
&& getOpCode()==Builtins.ABS)?id.getValueType():ValueType.FP64 );
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlocksize(id.getBlocksize());
}
//binary operations
else {
setBinaryOutputProperties(output);
// override computed value type for special cases
if( getOpCode() == Builtins.LOG )
output.setValueType(ValueType.FP64);
}
}
else {
// always unconditional (because unsupported operation)
Builtins op = getOpCode();
if( op==Builtins.EIGEN || op==Builtins.LU || op==Builtins.QR || op==Builtins.SVD
|| op==Builtins.LSTM || op==Builtins.LSTM_BACKWARD
|| op==Builtins.BATCH_NORM2D || op==Builtins.BATCH_NORM2D_BACKWARD)
raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
else
raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
}