in src/operator/tensor/elemwise_binary_op-inl.h [45:253]
// Element-wise binary operation where at least one operand is row_sparse (rsp),
// executed on CPU. The element-wise op itself is supplied by the template
// parameter OP (declared above this definition, outside this view) via
// OP::Map(lhs_value, rhs_value).
//
// The core of the algorithm is a merge-style walk over the two sorted
// row-index arrays of the sparse operands, emitting one output row per index
// encountered (OP applied to both rows where indices match, to one row plus an
// implicit zero otherwise).
//
// \param s       CPU stream kernels are launched on
// \param attrs   node attributes (not referenced in this body)
// \param ctx     operator context (not referenced in this body)
// \param lhs     left operand; dense only if lhs_may_be_dense
// \param rhs     right operand; dense only if rhs_may_be_dense
// \param req     output write request (write/add/in-place)
// \param output  destination NDArray; may be dense or row_sparse
// \param lhs_may_be_dense  whether a dense lhs is permitted
// \param rhs_may_be_dense  whether a dense rhs is permitted
// \param allow_inplace     whether output may alias one of the inputs
// \param scatter if true, only rows present in lhs produce output rows;
//                rhs-only rows are skipped entirely
void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
const OpReqType req,
const NDArray &output,
const bool lhs_may_be_dense,
const bool rhs_may_be_dense,
const bool allow_inplace,
const bool scatter) {
using namespace mshadow;
using namespace mshadow::expr;
// Pick whichever operand is row_sparse; its aux index type drives the
// IType switch below. At most one operand may be dense (checked next).
const NDArray& rsp = lhs.storage_type() == kRowSparseStorage ? lhs : rhs;
const bool is_dense_result = output.storage_type() == kDefaultStorage;
const bool lhs_is_dense = lhs.storage_type() == kDefaultStorage;
const bool rhs_is_dense = rhs.storage_type() == kDefaultStorage;
// NOTE(review): the message on the lhs check also says "rvalue" — it was
// probably meant to read "lvalue cannot be dense".
CHECK(!lhs_is_dense || lhs_may_be_dense) << "rvalue cannot be dense";
CHECK(!rhs_is_dense || rhs_may_be_dense) << "rvalue cannot be dense";
CHECK(!lhs_is_dense || !rhs_is_dense);
MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
// Only one item at most may be dense (lhs, rhs or result)
if (rhs_is_dense) {
// For right-side dense, in order to have sparse output, lhs input zero should
// always output zero (i.e. OP(0, x) == 0), so rows absent from lhs can be
// omitted from the sparse result.
CHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-4f);
CHECK(!is_dense_result); // Currently not handled
}
if (lhs_is_dense) {
// For left-side dense, in order to have sparse output, rhs input zero should
// always output zero (i.e. OP(x, 0) == 0), so rows absent from rhs can be
// omitted from the sparse result.
CHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-4f);
CHECK(!is_dense_result); // Currently not handled
}
// Memory Estimation: This is (roughly) the number of result rows. We may still
// need to subtract the number of common rows
bool lhs_in_place = false, rhs_in_place = false;
// For a dense operand, every row participates; for a sparse one, only the
// stored rows (size of its aux index array).
const size_t num_rows_l = lhs_is_dense ? lhs.shape()[0] :
lhs.aux_shape(rowsparse::kIdx).Size();
const size_t num_rows_r = rhs_is_dense ? rhs.shape()[0] :
rhs.aux_shape(rowsparse::kIdx).Size();
if (is_dense_result) {
output.CheckAndAlloc();
} else {
// Sparse output: allocate a worst-case number of index slots; the final
// aux shape is shrunk after the merge (see the end of this function).
if (rhs_is_dense || scatter) {
// Output rows can only come from lhs (scatter skips rhs-only rows;
// rhs-dense requires OP(0, x) == 0, checked above).
output.CheckAndAlloc({mshadow::Shape1(num_rows_l)});
} else if (lhs_is_dense) {
output.CheckAndAlloc({mshadow::Shape1(num_rows_r)});
} else {
// Both sparse: output may alias an input (in-place). Otherwise the
// worst case is fully disjoint index sets: num_rows_l + num_rows_r.
lhs_in_place = IsSameArray(lhs, output);
rhs_in_place = IsSameArray(rhs, output);
if (!lhs_in_place && !rhs_in_place) {
output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)});
} else {
CHECK_EQ(allow_inplace, true);
CHECK_EQ(is_dense_result, false);
if (lhs_in_place) {
// For in-place, zero L-value must always be zero output
DCHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < DType(1e-3));
} else {
// For in-place, zero R-value must always be zero output
DCHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < DType(1e-3));
}
}
}
}
// Indices — empty placeholder tensors for dense operands (never read for
// them; guarded by the lhs_is_dense / rhs_is_dense checks below).
const Tensor<cpu, 1, IType> indices_l = lhs_is_dense ?
Tensor<cpu, 1, IType>() :
lhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
const Tensor<cpu, 1, IType> indices_r = rhs_is_dense ?
Tensor<cpu, 1, IType>() :
rhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
Tensor<cpu, 1, IType> indices_out = is_dense_result ?
Tensor<cpu, 1, IType>() :
output.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
// Data
// TODO(cjolivier01): Change to get_with_shape() calls
const Tensor<cpu, 2, DType> data_l = AsRowise2D<DType>(s, lhs.data());
const Tensor<cpu, 2, DType> data_r = AsRowise2D<DType>(s, rhs.data());
Tensor<cpu, 2, DType> out = AsRowise2D<DType>(s, output.data());
// iter_l / iter_r walk the stored rows of lhs / rhs; iter_out tracks the
// next output row (a raw row number for dense output, a slot index for
// sparse output).
size_t iter_l = 0;
size_t iter_r = 0;
size_t iter_out = 0;
int32_t num_common_rows = 0;
if (is_dense_result) {
if (!num_rows_l && !num_rows_r) {
// Both inputs have no stored rows: fill the whole dense output with
// OP applied to implicit zeros.
const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
}
}
// Merge loop over the two sorted index lists.
while (iter_l < num_rows_l && iter_r < num_rows_r) {
// A dense operand implicitly has every row, so its "current index"
// mirrors the other (sparse) operand's index.
IType idx_l = lhs_is_dense ? indices_r[iter_r] : indices_l[iter_l];
IType idx_r = rhs_is_dense ? idx_l : indices_r[iter_r];
if (lhs_in_place) {
// Output shares lhs storage: only rows already present in lhs can be
// written, so skip rhs rows that precede the current lhs row.
while (idx_r < idx_l && ++iter_r < num_rows_r) {
idx_r = indices_r[iter_r];
}
if (iter_r >= num_rows_r) {
break;
}
} else if (rhs_in_place) {
// Symmetric case: output shares rhs storage.
while (idx_l < idx_r && ++iter_l < num_rows_l) {
idx_l = indices_l[iter_l];
}
if (iter_l >= num_rows_l) {
break;
}
}
if (is_dense_result) {
// Fill the gap of rows absent from both inputs, up to the next stored
// row (the DCHECK below implies FillDense advances iter_out to
// min(idx_l, idx_r)).
iter_out = FillDense<DType, OP>(s, idx_l, idx_r, req, &out, iter_out);
DCHECK_EQ(iter_out, static_cast<size_t>(std::min(idx_l, idx_r)));
}
if (idx_l == idx_r) {
// Same row
if (!is_dense_result) {
indices_out[iter_out] = idx_l;
}
// Sparse operands advance their iterator; dense operands are indexed
// directly by row number.
Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size());
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_);
});
num_common_rows++;
} else if (idx_l < idx_r) {
// Left only: rhs row is an implicit zero.
if (!is_dense_result) {
indices_out[iter_out] = idx_l;
}
Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_);
});
} else {
// Right only: lhs row is an implicit zero.
if (scatter) {
// Scatter mode emits only rows present in lhs.
++iter_r;
continue; // skip '++iter_out' below
}
if (!is_dense_result) {
indices_out[iter_out] = idx_r;
}
Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_);
});
}
++iter_out;
}
// Evaluate the remaining rows beyond the l and r value row intersection.
// Tail of lhs (rhs exhausted). Skipped when rhs is in-place output, since
// only rows already stored in rhs could be written.
while (iter_l < num_rows_l && !lhs_is_dense && !rhs_in_place) {
if (!is_dense_result) {
indices_out[iter_out] = indices_l[iter_l];
} else {
// Fill zero-rows between the previous output row and this lhs row.
const IType idx_l = indices_l[iter_l];
iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_l, req, &out, iter_out);
}
Tensor<cpu, 1, DType> lvalue = data_l[iter_l++];
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_);
});
}
// Tail of rhs (lhs exhausted). Skipped for lhs-in-place output and for
// scatter mode (rhs-only rows produce no output there).
while (iter_r < num_rows_r && !rhs_is_dense && !lhs_in_place && !scatter) {
if (!is_dense_result) {
indices_out[iter_out] = indices_r[iter_r];
} else {
const IType idx_r = indices_r[iter_r];
iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_r, req, &out, iter_out);
}
Tensor<cpu, 1, DType> rvalue = data_r[iter_r++];
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_);
});
}
if (is_dense_result) {
// Fill any trailing rows absent from both inputs with OP on zeros.
const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
} else {
if (lhs_in_place) {
CHECK_LE(iter_out, num_rows_l);
}
if (rhs_in_place) {
CHECK_LE(iter_out, num_rows_r);
}
DCHECK_LE(iter_out, num_rows_l + num_rows_r); // Make sure that we didn't overrun
nnvm::TShape new_shape = output.aux_shape(rowsparse::kIdx);
CHECK_LE(iter_out, new_shape.Size());
if (!rhs_is_dense && !lhs_is_dense && !lhs_in_place && !rhs_in_place && !scatter) {
// Reduce the first-dimension size by the number of common rows, since
// the worst-case allocation above counted matching rows twice.
new_shape[0] -= num_common_rows;
output.set_aux_shape(rowsparse::kIdx, new_shape);
}
}
});
});
}