in src/operator/operator_util.cc [850:978]
void SimpleOpRegEntryImpl::RegisterBinaryImperative() {
CHECK_EQ(reg_counter_, 1);
// The body to be registered
auto body = [this] (NDArray** used_vars,
real_t* s,
NDArray** mutate_vars,
int num_params,
char** param_keys,
char** param_vals) {
NDArray& lhs = *used_vars[0];
NDArray& rhs = *used_vars[1];
NDArray* out = mutate_vars[0];
// setup env.
EnvArguments env;
if (enable_scalar_) env.scalar = s[0];
if (enable_kwargs_) {
for (int i = 0; i < num_params; ++i) {
env.kwargs.emplace_back(std::make_pair(
std::string(param_keys[i]), std::string(param_vals[i])));
}
} else {
CHECK_EQ(num_params, 0)
<< "operator " << this->name << " do not take keyword arguments";
}
// shape inference.
TShape dshape;
if (binary_shape_ != nullptr) {
dshape = binary_shape_(lhs.shape(), rhs.shape(), env);
} else {
CHECK_EQ(lhs.shape(), rhs.shape()) << "operands shape mismatch";
dshape = lhs.shape();
}
// no check if all of them are on cpu
if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask) {
CHECK(lhs.ctx() == rhs.ctx())
<< "operands context mismatch " << lhs.ctx().dev_type << " " << lhs.ctx().dev_id << \
" vs. " << rhs.ctx().dev_type << " " << rhs.ctx().dev_id;
}
CHECK_EQ(lhs.dtype(), rhs.dtype()) << "operands type mismatch";
// check output shape.
if (out->is_none()) {
*out = NDArray(dshape, lhs.ctx(), true, lhs.dtype());
} else {
CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
CHECK(out->dtype() == lhs.dtype()) << "target data type mismatch";
CHECK(out->shape() == dshape) << "target shape mismatch "
<< out->shape() << " vs. " << dshape;
}
// important: callback must always capture by value
NDArray ret = *out;
// get the const variables
std::vector<Engine::VarHandle> const_vars;
if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());
// request resources.
std::vector<Engine::VarHandle> write_vars = {ret.var()};
for (ResourceRequest req : resource_requests_) {
env.resource.push_back(ResourceManager::Get()->Request(lhs.ctx(), req));
write_vars.push_back(env.resource.back().var);
}
// check if the function exist
int dev_mask = lhs.ctx().dev_mask();
// error message
if (static_cast<size_t>(dev_mask) >= fbinary_.size() ||
fbinary_[dev_mask] == nullptr) {
if (dev_mask == gpu::kDevMask) {
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
LOG(FATAL) << "Function " << this->name
<< "not registered for device " << dev_mask;
}
// invoke the function
BinaryFunction fun = fbinary_[dev_mask];
OpReqType req = kWriteTo;
if (lhs.var() == ret.var()) {
req = kWriteInplace;
CHECK(binary_forward_inplace_lhs_out_)
<< "inplace operation is not enabled for operator " << name;
}
if (rhs.var() == ret.var()) {
LOG(ERROR) << " operation " << this->name
<< " warning, perform inplace operation with right operand, may not be supported";
}
Engine::Get()->PushSync([lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) {
TBlob tmp = ret.data();
(*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx);
#if MXNET_USE_CUDA
if (dev_mask == gpu::kDevMask) {
ctx.get_stream<gpu>()->Wait();
}
#endif
}, lhs.ctx(), const_vars, write_vars,
FnProperty::kNormal, 0, PROFILER_MESSAGE("RegisterBinaryImperative"));
};
// register the function.
NDArrayReg()
.set_body(body)
.set_num_use_vars(2)
.set_num_mutate_vars(1);
if (enable_scalar_) {
if (scalar_type_mask_ == kArrayBeforeScalar) {
NDArrayReg()
.set_num_scalars(1)
.set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget)
.add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function")
.add_argument("scalar", "float", "scalar input to the function");
} else {
NDArrayReg()
.set_num_scalars(1)
.set_type_mask(kScalarArgBeforeNDArray | kAcceptEmptyMutateTarget)
.add_argument("scalar", "float", "scalar input to the function")
.add_argument("src", "NDArray-or-Symbol", "Source input to the function")
.add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function");
}
} else {
NDArrayReg()
.set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget)
.add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function");
}
}