in src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java [299:417]
public void setOutputDims()
{
	// Derives the output shape (_rows, _cols) and output data type (_dataType)
	// of this node from its operation type (_type) and its inputs. Operation
	// types not listed below end up in the default branch and raise an error.
	switch( _type ) {
		// *_ADD vector operations: the shape is copied from input 0 when
		// input 1 is a scalar, and from input 1 otherwise
		case VECT_MULT_ADD: case VECT_DIV_ADD: case VECT_MINUS_ADD: case VECT_PLUS_ADD:
		case VECT_POW_ADD: case VECT_MIN_ADD: case VECT_MAX_ADD:
		case VECT_EQUAL_ADD: case VECT_NOTEQUAL_ADD:
		case VECT_LESS_ADD: case VECT_LESSEQUAL_ADD:
		case VECT_GREATER_ADD: case VECT_GREATEREQUAL_ADD:
		case VECT_CBIND_ADD: case VECT_XOR_ADD: {
			final int shapePos = (_inputs.get(1).getDataType() == DataType.SCALAR) ? 0 : 1;
			_rows = _inputs.get(shapePos)._rows;
			_cols = _inputs.get(shapePos)._cols;
			_dataType = DataType.MATRIX;
			break;
		}

		// cbind: same rows as input 0, one more column than input 0
		case VECT_CBIND: {
			_rows = _inputs.get(0)._rows;
			_cols = _inputs.get(0)._cols + 1;
			_dataType = DataType.MATRIX;
			break;
		}

		// outer product accumulate: rows are the columns of input 0,
		// columns are the columns of input 1
		case VECT_OUTERMULT_ADD: {
			_rows = _inputs.get(0)._cols;
			_cols = _inputs.get(1)._cols;
			_dataType = DataType.MATRIX;
			break;
		}

		// *_SCALAR and plain vector operations: the shape is copied from
		// input 1 when input 0 is a scalar, and from input 0 otherwise
		case VECT_DIV_SCALAR: case VECT_MULT_SCALAR: case VECT_MINUS_SCALAR: case VECT_PLUS_SCALAR:
		case VECT_XOR_SCALAR: case VECT_BITWAND_SCALAR: case VECT_POW_SCALAR:
		case VECT_MIN_SCALAR: case VECT_MAX_SCALAR:
		case VECT_EQUAL_SCALAR: case VECT_NOTEQUAL_SCALAR:
		case VECT_LESS_SCALAR: case VECT_LESSEQUAL_SCALAR:
		case VECT_GREATER_SCALAR: case VECT_GREATEREQUAL_SCALAR:
		case VECT_DIV: case VECT_MULT: case VECT_MINUS: case VECT_PLUS:
		case VECT_XOR: case VECT_BITWAND:
		case VECT_MIN: case VECT_MAX:
		case VECT_EQUAL: case VECT_NOTEQUAL:
		case VECT_LESS: case VECT_LESSEQUAL:
		case VECT_GREATER: case VECT_GREATEREQUAL:
		case VECT_BIASADD: case VECT_BIASMULT: {
			final int shapePos = (_inputs.get(0).getDataType() == DataType.SCALAR) ? 1 : 0;
			_rows = _inputs.get(shapePos)._rows;
			_cols = _inputs.get(shapePos)._cols;
			_dataType = DataType.MATRIX;
			break;
		}

		// vector-matrix multiply: rows of input 0, columns of input 1
		case VECT_MATRIXMULT: {
			_rows = _inputs.get(0)._rows;
			_cols = _inputs.get(1)._cols;
			_dataType = DataType.MATRIX;
			break;
		}

		// everything below yields a scalar, recorded as 0 x 0
		case ROWMAXS_VECTMULT: case DOT_PRODUCT:
		// scalar arithmetic
		case MULT: case DIV: case PLUS: case MINUS:
		case MINUS1_MULT: case MINUS_NZ: case MODULUS: case INTDIV:
		// scalar comparison
		case LESS: case LESSEQUAL: case GREATER: case GREATEREQUAL:
		case EQUAL: case NOTEQUAL:
		// scalar logic and remaining scalar operations
		case MIN: case MAX: case AND: case OR: case XOR: case BITWAND:
		case LOG: case LOG_NZ: case POW: case SEQ_RIX: {
			_rows = 0;
			_cols = 0;
			_dataType = DataType.SCALAR;
			break;
		}

		default:
			throw new RuntimeException("Unknown CNodeBinary type: " + _type);
	}
}