NDArray::RObjectType DispatchOps()

in R-package/src/ndarray.cc [538:618]


// Dispatch an R binary arithmetic operator in which at least one operand is
// an NDArray.
//
// op  : R string naming the operator. "+", "-", "*", "/" and "%%" are handled.
//       For the four single-character operators only the first character is
//       inspected; modulo must be spelled exactly "%%".
// lhs : left operand, parsed by ParseNDArrayArg.
// rhs : right operand, parsed by ParseNDArrayArg.
//
// Returns the result handle wrapped by NDArray::RObject. Any other operator
// is reported through RLOG_FATAL.
NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
  // function handles
  static OpHandle plus = NDArrayFunction::FindHandle("_plus");
  static OpHandle plus_scalar = NDArrayFunction::FindHandle("_plus_scalar");
  static OpHandle minus = NDArrayFunction::FindHandle("_minus");
  static OpHandle minus_scalar = NDArrayFunction::FindHandle("_minus_scalar");
  static OpHandle rminus_scalar = NDArrayFunction::FindHandle("_rminus_scalar");
  static OpHandle mul = NDArrayFunction::FindHandle("_mul");
  static OpHandle mul_scalar = NDArrayFunction::FindHandle("_mul_scalar");
  static OpHandle div = NDArrayFunction::FindHandle("_div");
  static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar");
  static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar");
  static OpHandle mod = NDArrayFunction::FindHandle("_mod");
  static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar");
  static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar");
  // parse the arguments
  // NOTE(review): ParseNDArrayArg appears to return true when the operand is
  // an NDArray (filling the handle) and otherwise to leave the operand in the
  // string slot -- confirm against its definition.
  std::string values[2];
  NDArrayHandle handles[2];
  NDArrayHandle out = nullptr;
  bool lhs_nd = ParseNDArrayArg(lhs, &handles[0], &values[0]);
  bool rhs_nd = ParseNDArrayArg(rhs, &handles[1], &values[1]);
  RCHECK(lhs_nd || rhs_nd);

  // Pick the kernel from which operands are NDArrays:
  //   both      -> elementwise op on the two handles
  //   lhs only  -> "array OP scalar" kernel
  //   rhs only  -> "scalar OP array" kernel (the reversed variant for
  //                non-commutative operators)
  auto apply = [&](OpHandle both, OpHandle array_scalar,
                   OpHandle scalar_array) -> NDArrayHandle {
    if (lhs_nd && rhs_nd) {
      return BinaryOp(both, handles);
    }
    if (lhs_nd) {
      return BinaryScalarOp(array_scalar, handles[0], values[1]);
    }
    return BinaryScalarOp(scalar_array, handles[1], values[0]);
  };

  // create output and dispatch.
  const std::string sop = Rcpp::as<std::string>(op);
  char key = sop.empty() ? '\0' : sop[0];
  // R has several operators that start with '%' (e.g. "%/%" integer
  // division). Only "%%" is modulo; the others must reach the unsupported
  // path instead of silently computing a remainder.
  if (key == '%' && sop != "%%") {
    key = '\0';
  }
  switch (key) {
    case '+': {
      out = apply(plus, plus_scalar, plus_scalar);
      break;
    }
    case '-': {
      out = apply(minus, minus_scalar, rminus_scalar);
      break;
    }
    case '*': {
      out = apply(mul, mul_scalar, mul_scalar);
      break;
    }
    case '/': {
      out = apply(div, div_scalar, rdiv_scalar);
      break;
    }
    case '%': {
      out = apply(mod, mod_scalar, rmod_scalar);
      break;
    }
    default: {
      RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
    }
  }
  return NDArray::RObject(out, true);
}