void SimpleOpRegEntryImpl::RegisterBinaryImperative()

in src/operator/operator_util.cc [850:978]


void SimpleOpRegEntryImpl::RegisterBinaryImperative() {
  CHECK_EQ(reg_counter_, 1);
  // The body to be registered
  auto body = [this] (NDArray** used_vars,
                      real_t* s,
                      NDArray** mutate_vars,
                      int num_params,
                      char** param_keys,
                      char** param_vals) {
    NDArray& lhs = *used_vars[0];
    NDArray& rhs = *used_vars[1];
    NDArray* out = mutate_vars[0];
    // setup env.
    EnvArguments env;
    if (enable_scalar_) env.scalar = s[0];
    if (enable_kwargs_) {
      for (int i = 0; i < num_params; ++i) {
        env.kwargs.emplace_back(std::make_pair(
            std::string(param_keys[i]), std::string(param_vals[i])));
      }
    } else {
      CHECK_EQ(num_params, 0)
        << "operator " << this->name << " do not take keyword arguments";
    }

    // shape inference.
    TShape dshape;
    if (binary_shape_ != nullptr) {
      dshape = binary_shape_(lhs.shape(), rhs.shape(), env);
    } else {
      CHECK_EQ(lhs.shape(), rhs.shape()) << "operands shape mismatch";
      dshape = lhs.shape();
    }

    // no check if all of them are on cpu
    if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask) {
      CHECK(lhs.ctx() == rhs.ctx())
        << "operands context mismatch " << lhs.ctx().dev_type << " " << lhs.ctx().dev_id << \
        " vs. " << rhs.ctx().dev_type << " " << rhs.ctx().dev_id;
    }
    CHECK_EQ(lhs.dtype(), rhs.dtype()) << "operands type mismatch";

    // check output shape.
    if (out->is_none()) {
      *out = NDArray(dshape, lhs.ctx(), true, lhs.dtype());
    } else {
      CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
      CHECK(out->dtype() == lhs.dtype()) << "target data type mismatch";
      CHECK(out->shape() == dshape) << "target shape mismatch "
      << out->shape() << " vs. " << dshape;
    }
    // important: callback must always capture by value
    NDArray ret = *out;
    // get the const variables
    std::vector<Engine::VarHandle> const_vars;
    if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
    if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());

    // request resources.
    std::vector<Engine::VarHandle> write_vars = {ret.var()};
    for (ResourceRequest req : resource_requests_) {
      env.resource.push_back(ResourceManager::Get()->Request(lhs.ctx(), req));
      write_vars.push_back(env.resource.back().var);
    }

    // check if the function exist
    int dev_mask = lhs.ctx().dev_mask();
    // error message
    if (static_cast<size_t>(dev_mask) >= fbinary_.size() ||
        fbinary_[dev_mask] == nullptr) {
      if (dev_mask == gpu::kDevMask) {
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
      }
      LOG(FATAL) << "Function " << this->name
                 << "not registered for device " << dev_mask;
    }
    // invoke the function
    BinaryFunction fun = fbinary_[dev_mask];
    OpReqType req = kWriteTo;
    if (lhs.var() == ret.var()) {
      req = kWriteInplace;
      CHECK(binary_forward_inplace_lhs_out_)
          << "inplace operation is not enabled for operator " << name;
    }
    if (rhs.var() == ret.var()) {
      LOG(ERROR) << " operation " << this->name
        << " warning, perform inplace operation with right operand, may not be supported";
    }

    Engine::Get()->PushSync([lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) {
        TBlob tmp = ret.data();
        (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx);
        #if MXNET_USE_CUDA
        if (dev_mask == gpu::kDevMask) {
          ctx.get_stream<gpu>()->Wait();
        }
        #endif
      }, lhs.ctx(), const_vars, write_vars,
      FnProperty::kNormal, 0, PROFILER_MESSAGE("RegisterBinaryImperative"));
  };
  // register the function.
  NDArrayReg()
      .set_body(body)
      .set_num_use_vars(2)
      .set_num_mutate_vars(1);
  if (enable_scalar_) {
    if (scalar_type_mask_ == kArrayBeforeScalar) {
      NDArrayReg()
          .set_num_scalars(1)
          .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget)
          .add_argument("lhs", "NDArray-or-Symbol", "Left operand  to the function")
          .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function")
          .add_argument("scalar", "float", "scalar input to the function");
    } else {
      NDArrayReg()
          .set_num_scalars(1)
          .set_type_mask(kScalarArgBeforeNDArray | kAcceptEmptyMutateTarget)
          .add_argument("scalar", "float", "scalar input to the function")
          .add_argument("src", "NDArray-or-Symbol", "Source input to the function")
          .add_argument("lhs", "NDArray-or-Symbol", "Left operand  to the function")
          .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function");
    }
  } else {
    NDArrayReg()
        .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget)
        .add_argument("lhs", "NDArray-or-Symbol", "Left operand  to the function")
        .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function");
  }
}