in R-package/src/ndarray.cc [557:709]
// Dispatch an R binary operator applied to MXNDArray operands.
//
// op  : R value convertible to std::string that names the operator
//       ("+", "-", "*", "/", "%%", "==", "!=", ">", ">=", "<", "<=").
// lhs : left operand; rhs : right operand. At least one of them must be
//       recognised as an NDArray by ParseNDArrayArg (enforced by RCHECK).
//
// Returns the result handle wrapped with NDArray::RObject(out, true).
//
// The operator is matched on its full text. Matching on the first character
// alone would route every operator that starts with '%' (for example R's
// integer division "%/%") to the modulo functions and silently return a
// wrong result; such operators are now reported as unsupported instead.
NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
  // Function handles; being function-local statics, each lookup runs only once.
  static OpHandle plus = NDArrayFunction::FindHandle("_plus");
  static OpHandle plus_scalar = NDArrayFunction::FindHandle("_plus_scalar");
  static OpHandle minus = NDArrayFunction::FindHandle("_minus");
  static OpHandle minus_scalar = NDArrayFunction::FindHandle("_minus_scalar");
  static OpHandle rminus_scalar = NDArrayFunction::FindHandle("_rminus_scalar");
  static OpHandle mul = NDArrayFunction::FindHandle("_mul");
  static OpHandle mul_scalar = NDArrayFunction::FindHandle("_mul_scalar");
  static OpHandle div = NDArrayFunction::FindHandle("_div");
  static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar");
  static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar");
  static OpHandle mod = NDArrayFunction::FindHandle("_mod");
  static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar");
  static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar");
  static OpHandle equal = NDArrayFunction::FindHandle("_equal");
  static OpHandle equal_scalar = NDArrayFunction::FindHandle("_equal_scalar");
  static OpHandle not_equal = NDArrayFunction::FindHandle("_not_equal");
  static OpHandle not_equal_scalar = NDArrayFunction::FindHandle("_not_equal_scalar");
  static OpHandle greater = NDArrayFunction::FindHandle("_greater");
  static OpHandle greater_scalar = NDArrayFunction::FindHandle("_greater_scalar");
  static OpHandle greater_equal = NDArrayFunction::FindHandle("_greater_equal");
  static OpHandle greater_equal_scalar = NDArrayFunction::FindHandle("_greater_equal_scalar");
  static OpHandle lesser = NDArrayFunction::FindHandle("_lesser");
  static OpHandle lesser_scalar = NDArrayFunction::FindHandle("_lesser_scalar");
  static OpHandle lesser_equal = NDArrayFunction::FindHandle("_lesser_equal");
  static OpHandle lesser_equal_scalar = NDArrayFunction::FindHandle("_lesser_equal_scalar");

  // Parse the arguments. ParseNDArrayArg reports whether the operand is an
  // NDArray; presumably it fills the handle in that case and the textual
  // scalar value otherwise (helper is defined outside this block).
  std::string values[2];
  NDArrayHandle handles[2];
  NDArrayHandle out = nullptr;
  bool lhs_nd = ParseNDArrayArg(lhs, &handles[0], &values[0]);
  bool rhs_nd = ParseNDArrayArg(rhs, &handles[1], &values[1]);
  RCHECK(lhs_nd || rhs_nd);

  // Shared three-way dispatch:
  //   array  op array  -> BinaryOp(binary, handles)
  //   array  op scalar -> BinaryScalarOp(scalar,  handles[0], values[1])
  //   scalar op array  -> BinaryScalarOp(rscalar, handles[1], values[0])
  // `rscalar` is the function that yields the right answer when the array is
  // the right-hand operand: the same function for commutative operators, the
  // "r"-prefixed one for -, / and %%, and the mirrored comparison for
  // ordering operators (s > a is evaluated as a < s).
  auto apply = [&](const OpHandle &binary, const OpHandle &scalar,
                   const OpHandle &rscalar) -> NDArrayHandle {
    if (lhs_nd && rhs_nd) {
      return BinaryOp(binary, handles);
    }
    if (lhs_nd) {
      return BinaryScalarOp(scalar, handles[0], values[1]);
    }
    return BinaryScalarOp(rscalar, handles[1], values[0]);
  };

  // Create output and dispatch on the exact operator text.
  std::string sop = Rcpp::as<std::string>(op);
  if (sop == "+") {
    out = apply(plus, plus_scalar, plus_scalar);
  } else if (sop == "-") {
    out = apply(minus, minus_scalar, rminus_scalar);
  } else if (sop == "*") {
    out = apply(mul, mul_scalar, mul_scalar);
  } else if (sop == "/") {
    out = apply(div, div_scalar, rdiv_scalar);
  } else if (sop == "%%") {
    out = apply(mod, mod_scalar, rmod_scalar);
  } else if (sop == "==") {
    out = apply(equal, equal_scalar, equal_scalar);
  } else if (sop == "!=") {
    out = apply(not_equal, not_equal_scalar, not_equal_scalar);
  } else if (sop == ">=") {
    out = apply(greater_equal, greater_equal_scalar, lesser_equal_scalar);
  } else if (sop == ">") {
    out = apply(greater, greater_scalar, lesser_scalar);
  } else if (sop == "<=") {
    out = apply(lesser_equal, lesser_equal_scalar, greater_equal_scalar);
  } else if (sop == "<") {
    out = apply(lesser, lesser_scalar, greater_scalar);
  } else {
    RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
  }
  return NDArray::RObject(out, true);
}