void ElemwiseBinaryOp::RspRspOp()

in src/operator/tensor/elemwise_binary_op-inl.h [45:253]


void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<cpu> *s,
                                const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const NDArray &lhs,
                                const NDArray &rhs,
                                const OpReqType req,
                                const NDArray &output,
                                const bool lhs_may_be_dense,
                                const bool rhs_may_be_dense,
                                const bool allow_inplace,
                                const bool scatter) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const NDArray& rsp = lhs.storage_type() == kRowSparseStorage ? lhs : rhs;
  const bool is_dense_result = output.storage_type() == kDefaultStorage;
  const bool lhs_is_dense = lhs.storage_type() == kDefaultStorage;
  const bool rhs_is_dense = rhs.storage_type() == kDefaultStorage;
  CHECK(!lhs_is_dense || lhs_may_be_dense) << "lvalue cannot be dense";
  CHECK(!rhs_is_dense || rhs_may_be_dense) << "rvalue cannot be dense";
  CHECK(!lhs_is_dense || !rhs_is_dense);
  MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
    MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
      // Only one item at most may be dense (lhs, rhs or result)
      if (rhs_is_dense) {
        // For right-side dense, in order to have sparse output, lhs input zero should
        // always output zero
        CHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < 1e-4f);
        CHECK(!is_dense_result);  // Currently not handled
      }
      if (lhs_is_dense) {
        // For left-side dense, in order to have sparse output, rhs input zero should
        // always output zero
        CHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < 1e-4f);
        CHECK(!is_dense_result);  // Currently not handled
      }

      // Memory Estimation: This is (roughly) the number of result rows. We may still
      // need to subtract the number of common rows
      bool lhs_in_place = false, rhs_in_place = false;
      const size_t num_rows_l = lhs_is_dense ? lhs.shape()[0] :
                                               lhs.aux_shape(rowsparse::kIdx).Size();
      const size_t num_rows_r = rhs_is_dense ? rhs.shape()[0] :
                                               rhs.aux_shape(rowsparse::kIdx).Size();
      if (is_dense_result) {
        output.CheckAndAlloc();
      } else {
        if (rhs_is_dense || scatter) {
          output.CheckAndAlloc({mshadow::Shape1(num_rows_l)});
        } else if (lhs_is_dense) {
          output.CheckAndAlloc({mshadow::Shape1(num_rows_r)});
        } else {
          lhs_in_place = IsSameArray(lhs, output);
          rhs_in_place = IsSameArray(rhs, output);
          if (!lhs_in_place && !rhs_in_place) {
            output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)});
          } else {
            CHECK_EQ(allow_inplace, true);
            CHECK_EQ(is_dense_result, false);
            if (lhs_in_place) {
              // For in-place, zero L-value must always be zero output
              DCHECK(std::fabs(static_cast<float>(OP::Map(DType(0), DType(99)))) < DType(1e-3));
            } else {
              // For in-place, zero R-value must always be zero output
              DCHECK(std::fabs(static_cast<float>(OP::Map(DType(99), DType(0)))) < DType(1e-3));
            }
          }
        }
      }

      // Indices
      const Tensor<cpu, 1, IType> indices_l = lhs_is_dense ?
                                              Tensor<cpu, 1, IType>() :
                                              lhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
      const Tensor<cpu, 1, IType> indices_r = rhs_is_dense ?
                                              Tensor<cpu, 1, IType>() :
                                              rhs.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);
      Tensor<cpu, 1, IType> indices_out = is_dense_result ?
                                          Tensor<cpu, 1, IType>() :
                                          output.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>(s);

      // Data
      // TODO(cjolivier01): Change to get_with_shape() calls
      const Tensor<cpu, 2, DType> data_l = AsRowise2D<DType>(s, lhs.data());
      const Tensor<cpu, 2, DType> data_r = AsRowise2D<DType>(s, rhs.data());
      Tensor<cpu, 2, DType> out = AsRowise2D<DType>(s, output.data());

      size_t iter_l = 0;
      size_t iter_r = 0;
      size_t iter_out = 0;
      int32_t num_common_rows = 0;

      if (is_dense_result) {
        if (!num_rows_l && !num_rows_r) {
          const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
          iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
        }
      }

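      // Merge phase: walk both sorted row-index arrays with the two cursors iter_l and
      // iter_r, emitting at most one output row per distinct row index.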
      while (iter_l < num_rows_l && iter_r < num_rows_r) {
        IType idx_l = lhs_is_dense ? indices_r[iter_r] : indices_l[iter_l];
        IType idx_r = rhs_is_dense ? idx_l : indices_r[iter_r];
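        // When the output aliases one operand (in-place), rows present only in the other
        // operand cannot be inserted, so skip past them.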
        if (lhs_in_place) {
          while (idx_r < idx_l && ++iter_r < num_rows_r) {
            idx_r = indices_r[iter_r];
          }
          if (iter_r >= num_rows_r) {
            break;
          }
        } else if (rhs_in_place) {
          while (idx_l < idx_r && ++iter_l < num_rows_l) {
            idx_l = indices_l[iter_l];
          }
          if (iter_l >= num_rows_l) {
            break;
          }
        }
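        // For a dense result, rows absent from both sparse inputs up to the next
        // populated row are filled with OP(0, 0).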
        if (is_dense_result) {
          iter_out = FillDense<DType, OP>(s, idx_l, idx_r, req, &out, iter_out);
          DCHECK_EQ(iter_out, static_cast<size_t>(std::min(idx_l, idx_r)));
        }
        if (idx_l == idx_r) {
          // Same row
          if (!is_dense_result) {
            indices_out[iter_out] = idx_l;
          }
          Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
          Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
          DCHECK_EQ(lvalue.shape_.Size(), rvalue.shape_.Size());
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
              s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_, rvalue.dptr_);
          });
          num_common_rows++;
        } else if (idx_l < idx_r) {
          // Left only
          if (!is_dense_result) {
            indices_out[iter_out] = idx_l;
          }
          Tensor<cpu, 1, DType> lvalue = !lhs_is_dense ? data_l[iter_l++] : data_l[idx_l];
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
              s, lvalue.shape_.Size(), out[iter_out].dptr_, lvalue.dptr_);
          });
        } else {
          // Right only
          if (scatter) {
            ++iter_r;
            continue;  // skip '++iter_out' below
          }
          if (!is_dense_result) {
            indices_out[iter_out] = idx_r;
          }
          Tensor<cpu, 1, DType> rvalue = !rhs_is_dense ? data_r[iter_r++] : data_r[idx_r];
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
              s, rvalue.shape_.Size(), out[iter_out].dptr_, rvalue.dptr_);
          });
        }
        ++iter_out;
      }
      // Evaluate the remaining rows beyond the lhs and rhs row intersection
      while (iter_l < num_rows_l && !lhs_is_dense && !rhs_in_place) {
        if (!is_dense_result) {
          indices_out[iter_out] = indices_l[iter_l];
        } else {
          const IType idx_l = indices_l[iter_l];
          iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_l, req, &out, iter_out);
        }
        Tensor<cpu, 1, DType> lvalue = data_l[iter_l++];
        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
          mxnet_op::Kernel<MissingRValueOp<OP, Req>, cpu>::Launch(
            s, lvalue.shape_.Size(), out[iter_out++].dptr_, lvalue.dptr_);
        });
      }
      while (iter_r < num_rows_r && !rhs_is_dense && !lhs_in_place && !scatter) {
        if (!is_dense_result) {
          indices_out[iter_out] = indices_r[iter_r];
        } else {
          const IType idx_r = indices_r[iter_r];
          iter_out = FillDense<DType, OP>(s, lhs.shape()[0], idx_r, req, &out, iter_out);
        }
        Tensor<cpu, 1, DType> rvalue = data_r[iter_r++];
        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
          mxnet_op::Kernel<MissingLValueOp<OP, Req>, cpu>::Launch(
            s, rvalue.shape_.Size(), out[iter_out++].dptr_, rvalue.dptr_);
        });
      }
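      // Finalize: fill any trailing rows of a dense result; for a sparse result, check the
      // index bounds and, where applicable, trim the row count by the number of common rows.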
      if (is_dense_result) {
        const size_t all_rows = static_cast<size_t>(lhs.shape()[0]);
        iter_out = FillDense<DType, OP>(s, all_rows, all_rows, req, &out, iter_out);
      } else {
        if (lhs_in_place) {
          CHECK_LE(iter_out, num_rows_l);
        }
        if (rhs_in_place) {
          CHECK_LE(iter_out, num_rows_r);
        }
        DCHECK_LE(iter_out, num_rows_l + num_rows_r);  // Make sure that we didn't overrun
        nnvm::TShape new_shape = output.aux_shape(rowsparse::kIdx);
        CHECK_LE(iter_out, new_shape.Size());
        if (!rhs_is_dense && !lhs_is_dense && !lhs_in_place && !rhs_in_place && !scatter) {
          // Reduce the first-dimension size by the number of common rows
          new_shape[0] -= num_common_rows;
          output.set_aux_shape(rowsparse::kIdx, new_shape);
        }
      }
    });
  });
}
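
At its core, the routine is a two-cursor merge over the two sorted row-index arrays, classifying each row as present in both operands, left only, or right only. The standalone sketch below distills that pattern for illustration only; it uses plain C++ with no MXNet types, the names merge_rows and RowKind are invented here rather than part of the MXNet API, and it omits the dense-operand, in-place, scatter, and dense-result paths handled above.

#include <cstddef>
#include <vector>

// Illustrative sketch only: merges two sorted row-index lists the way the loop
// above does, classifying each row as present in both operands or in one only.
enum class RowKind { kBoth, kLeftOnly, kRightOnly };

template <typename IType, typename Visit>
void merge_rows(const std::vector<IType> &idx_l,
                const std::vector<IType> &idx_r,
                Visit visit) {
  std::size_t l = 0, r = 0;
  while (l < idx_l.size() && r < idx_r.size()) {
    if (idx_l[l] == idx_r[r]) {
      visit(idx_l[l], RowKind::kBoth);       // OP(lhs_row, rhs_row)
      ++l;
      ++r;
    } else if (idx_l[l] < idx_r[r]) {
      visit(idx_l[l], RowKind::kLeftOnly);   // missing rvalue: OP(lhs_row, 0)
      ++l;
    } else {
      visit(idx_r[r], RowKind::kRightOnly);  // missing lvalue: OP(0, rhs_row)
      ++r;
    }
  }
  while (l < idx_l.size()) visit(idx_l[l++], RowKind::kLeftOnly);
  while (r < idx_r.size()) visit(idx_r[r++], RowKind::kRightOnly);
}

In the real function, the role of the visit callback is played by the per-row kernel launches: op_with_req for rows found in both operands, and MissingRValueOp or MissingLValueOp for rows found in only one.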