in R-package/src/ndarray.cc [538:618]
// Dispatch an R arithmetic operator applied to (lhs, rhs) to the matching
// NDArray function and return the result wrapped as an R object.
//
// op  : R string naming the operator; "+", "-", "*", "/" and "%%" are handled.
// lhs : left operand, either an NDArray or a scalar value.
// rhs : right operand, either an NDArray or a scalar value.
//
// At least one operand must be an NDArray (enforced via RCHECK). Any other
// operator is reported through RLOG_FATAL.
NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
  // Function handles, looked up once and cached for the process lifetime.
  static OpHandle plus = NDArrayFunction::FindHandle("_plus");
  static OpHandle plus_scalar = NDArrayFunction::FindHandle("_plus_scalar");
  static OpHandle minus = NDArrayFunction::FindHandle("_minus");
  static OpHandle minus_scalar = NDArrayFunction::FindHandle("_minus_scalar");
  static OpHandle rminus_scalar = NDArrayFunction::FindHandle("_rminus_scalar");
  static OpHandle mul = NDArrayFunction::FindHandle("_mul");
  static OpHandle mul_scalar = NDArrayFunction::FindHandle("_mul_scalar");
  static OpHandle div = NDArrayFunction::FindHandle("_div");
  static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar");
  static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar");
  static OpHandle mod = NDArrayFunction::FindHandle("_mod");
  static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar");
  static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar");

  // Parse the arguments. Judging by how the results are used below, an
  // NDArray operand ends up in handles[i] and a scalar operand in values[i].
  std::string values[2];
  NDArrayHandle handles[2];
  NDArrayHandle out = nullptr;
  const bool lhs_nd = ParseNDArrayArg(lhs, &handles[0], &values[0]);
  const bool rhs_nd = ParseNDArrayArg(rhs, &handles[1], &values[1]);
  RCHECK(lhs_nd || rhs_nd);

  // Pick the array/array, array/scalar or scalar/array variant of one
  // operator. For commutative operators the caller passes the same handle
  // as both scalar_op and rscalar_op.
  auto apply = [&](OpHandle nd_op, OpHandle scalar_op,
                   OpHandle rscalar_op) -> NDArrayHandle {
    if (lhs_nd && rhs_nd) {
      return BinaryOp(nd_op, handles);
    }
    if (lhs_nd) {
      return BinaryScalarOp(scalar_op, handles[0], values[1]);
    }
    // Only rhs is an NDArray: use the "reversed" scalar variant so that
    // non-commutative operators keep the scalar on the left-hand side.
    return BinaryScalarOp(rscalar_op, handles[1], values[0]);
  };

  // Create output and dispatch.
  std::string sop = Rcpp::as<std::string>(op);
  switch (sop[0]) {
    case '+': {
      out = apply(plus, plus_scalar, plus_scalar);
      break;
    }
    case '-': {
      out = apply(minus, minus_scalar, rminus_scalar);
      break;
    }
    case '*': {
      out = apply(mul, mul_scalar, mul_scalar);
      break;
    }
    case '/': {
      out = apply(div, div_scalar, rdiv_scalar);
      break;
    }
    case '%': {
      // R has several operators that start with '%' (e.g. "%/%" for integer
      // division). Only "%%" is modulo; anything else must not be silently
      // evaluated as a remainder.
      if (sop == "%%") {
        out = apply(mod, mod_scalar, rmod_scalar);
      } else {
        RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
      }
      break;
    }
    default: {
      RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
    }
  }
  return NDArray::RObject(out, true);
}