NDArray::RObjectType DispatchOps()

in R-package/src/ndarray.cc [557:709]


/*!
 * \brief Dispatch an R binary operator (as received by the Ops group generic)
 *  on MXNDArray operands to the matching MXNet operator.
 *
 *  Supported operators: "+", "-", "*", "/", "%%", "==", "!=", ">", ">=",
 *  "<", "<=". Any other operator is reported via RLOG_FATAL.
 *
 *  At least one operand must be an NDArray. When both are, the elementwise
 *  NDArray operator is used. When only one is, the "_scalar" variant is
 *  used. If the scalar is on the left, the code picks the operand-swapped
 *  form of the operator:
 *  - "+", "*", "==" and "!=" are symmetric, so they reuse the same op.
 *  - "-", "/" and "%%" use the reversed ("_r...") scalar op.
 *  - Ordering comparisons use the mirrored op (e.g. s > x  ==>  x < s).
 *
 * \param op  Operator name as a character SEXP, e.g. "+", "%%", "<=".
 * \param lhs Left operand: an MXNDArray or a scalar value.
 * \param rhs Right operand: an MXNDArray or a scalar value.
 * \return The result handle wrapped via NDArray::RObject(out, true).
 */
NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
  // function handles, looked up by name once and cached in function statics
  static OpHandle plus = NDArrayFunction::FindHandle("_plus");
  static OpHandle plus_scalar = NDArrayFunction::FindHandle("_plus_scalar");
  static OpHandle minus = NDArrayFunction::FindHandle("_minus");
  static OpHandle minus_scalar = NDArrayFunction::FindHandle("_minus_scalar");
  static OpHandle rminus_scalar = NDArrayFunction::FindHandle("_rminus_scalar");
  static OpHandle mul = NDArrayFunction::FindHandle("_mul");
  static OpHandle mul_scalar = NDArrayFunction::FindHandle("_mul_scalar");
  static OpHandle div = NDArrayFunction::FindHandle("_div");
  static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar");
  static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar");
  static OpHandle mod = NDArrayFunction::FindHandle("_mod");
  static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar");
  static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar");
  static OpHandle equal = NDArrayFunction::FindHandle("_equal");
  static OpHandle equal_scalar = NDArrayFunction::FindHandle("_equal_scalar");
  static OpHandle not_equal = NDArrayFunction::FindHandle("_not_equal");
  static OpHandle not_equal_scalar = NDArrayFunction::FindHandle("_not_equal_scalar");
  static OpHandle greater = NDArrayFunction::FindHandle("_greater");
  static OpHandle greater_scalar = NDArrayFunction::FindHandle("_greater_scalar");
  static OpHandle greater_equal = NDArrayFunction::FindHandle("_greater_equal");
  static OpHandle greater_equal_scalar = NDArrayFunction::FindHandle("_greater_equal_scalar");
  static OpHandle lesser = NDArrayFunction::FindHandle("_lesser");
  static OpHandle lesser_scalar = NDArrayFunction::FindHandle("_lesser_scalar");
  static OpHandle lesser_equal = NDArrayFunction::FindHandle("_lesser_equal");
  static OpHandle lesser_equal_scalar = NDArrayFunction::FindHandle("_lesser_equal_scalar");
  // parse the arguments: each is either an NDArray handle or a scalar string
  std::string values[2];
  NDArrayHandle handles[2];
  NDArrayHandle out = nullptr;
  bool lhs_nd = ParseNDArrayArg(lhs, &handles[0], &values[0]);
  bool rhs_nd = ParseNDArrayArg(rhs, &handles[1], &values[1]);
  RCHECK(lhs_nd || rhs_nd);

  // Select between the three operand layouts for one operator family.
  //   nd_op:       NDArray op NDArray
  //   scalar_op:   NDArray op scalar
  //   r_scalar_op: scalar op NDArray, expressed as an op on (NDArray, scalar)
  auto dispatch = [&](OpHandle nd_op, OpHandle scalar_op,
                      OpHandle r_scalar_op) -> NDArrayHandle {
    if (lhs_nd && rhs_nd) {
      return BinaryOp(nd_op, handles);
    }
    if (lhs_nd) {
      return BinaryScalarOp(scalar_op, handles[0], values[1]);
    }
    return BinaryScalarOp(r_scalar_op, handles[1], values[0]);
  };

  // create output and dispatch.
  std::string sop = Rcpp::as<std::string>(op);
  switch (sop[0]) {
    case '+': {
      out = dispatch(plus, plus_scalar, plus_scalar);
      break;
    }
    case '-': {
      out = dispatch(minus, minus_scalar, rminus_scalar);
      break;
    }
    case '*': {
      out = dispatch(mul, mul_scalar, mul_scalar);
      break;
    }
    case '/': {
      out = dispatch(div, div_scalar, rdiv_scalar);
      break;
    }
    case '%': {
      // R's Ops group generic routes both "%%" (modulo) and "%/%" (integer
      // division) here. Only "%%" has a matching backend operator, so
      // matching on the first character alone would silently compute a
      // modulo for "%/%".
      if (sop == "%%") {
        out = dispatch(mod, mod_scalar, rmod_scalar);
      } else {
        RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
      }
      break;
    }
    case '=': {
      out = dispatch(equal, equal_scalar, equal_scalar);
      break;
    }
    case '!': {
      out = dispatch(not_equal, not_equal_scalar, not_equal_scalar);
      break;
    }
    case '>': {
      // scalar on the left mirrors the comparison: s >= x  <=>  x <= s
      if (sop == ">=") {
        out = dispatch(greater_equal, greater_equal_scalar, lesser_equal_scalar);
      } else {
        out = dispatch(greater, greater_scalar, lesser_scalar);
      }
      break;
    }
    case '<': {
      // scalar on the left mirrors the comparison: s <= x  <=>  x >= s
      if (sop == "<=") {
        out = dispatch(lesser_equal, lesser_equal_scalar, greater_equal_scalar);
      } else {
        out = dispatch(lesser, lesser_scalar, greater_scalar);
      }
      break;
    }
    default: {
      RLOG_FATAL << "Operator " << sop << " not supported for MXNDArray";
    }
  }
  return NDArray::RObject(out, true);
}