Status Merge()

in lingvo/core/ops/generic_input_op_kernels.cc [189:329]
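
Merge() stacks one bucket's worth of per-example tensors into batched
tensors in three steps: (1) if dynamic padding is configured, pad each
designated dimension out to the longest extent in the bucket; (2) validate
that every example now has matching dtypes and shapes per output; (3)
allocate outputs with a new leading dimension of num_samples and copy each
example into its own row, sharding the copies across merger threads.

As a minimal standalone sketch of step (1) -- not lingvo code; plain
std::vector rows stand in for Tensors, and PadToLongest is a hypothetical
name:

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  // Pad every row up to the longest row with 'pad_value', so the rows form
  // a dense [num_samples, max_len] batch: the 1-D analogue of Merge()'s
  // dynamic padding step.
  std::vector<std::vector<int>> PadToLongest(
      std::vector<std::vector<int>> rows, int pad_value) {
    std::size_t max_len = 0;
    for (const auto& r : rows) max_len = std::max(max_len, r.size());
    for (auto& r : rows) r.resize(max_len, pad_value);
    return rows;
  }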


  Status Merge(int64_t bucket_size, const std::vector<TensorVec>& samples,
               TensorVec* batch) override {
    CHECK(!samples.empty());
    const auto num_samples = samples.size();
    const auto num_outs = samples[0].size();

    std::vector<TensorVec> padded_samples(samples.begin(), samples.end());
    if (!dynamic_padding_dimensions_.empty()) {
      CHECK_EQ(dynamic_padding_dimensions_.size(), num_outs);
      CHECK_EQ(dynamic_padding_constants_.size(), num_outs);

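      // For each output j that has a padding dimension configured, grow every
      // sample's tensor along that dimension to the longest extent seen in
      // this bucket, filling new elements with the configured constant.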
      for (int j = 0; j < num_outs; ++j) {
        const int pad_dim = dynamic_padding_dimensions_[j];
        if (pad_dim == -1) {
          continue;
        }
        const int pad_value = dynamic_padding_constants_[j];

        int64_t max_length = 0;
        for (int i = 0; i < num_samples; ++i) {
          max_length = std::max(max_length, samples[i][j].dim_size(pad_dim));
        }

        for (int i = 0; i < num_samples; ++i) {
          const auto& src = samples[i][j];
          if (src.dims() > 0 && src.dim_size(pad_dim) < max_length) {
            DataType dtype = src.dtype();
            TensorShape dst_shape(src.shape());
            dst_shape.set_dim(pad_dim, max_length);
            Tensor dst(dtype, dst_shape);
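            // Fill 'dst' with the pad value, then overwrite its leading
            // slice with 'src' via a 2-D view (dims before pad_dim collapsed
            // into rows; pad_dim and later dims into columns), so only the
            // padded tail keeps pad_value. Only float, int32, and int64
            // outputs support dynamic padding.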
            switch (dtype) {
#define CASE(T)                                                  \
  case DataTypeToEnum<T>::value:                                 \
    dst.flat<T>().setConstant(pad_value);                        \
    if (src.NumElements() > 0) {                                 \
      auto src_t = src.flat_inner_outer_dims<T, 2>(pad_dim - 1); \
      auto dst_t = dst.flat_inner_outer_dims<T, 2>(pad_dim - 1); \
      typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes;        \
      dst_t.slice(DSizes(), DSizes(src_t.dimensions())) = src_t; \
    }                                                            \
    break

              CASE(float);
              CASE(int32);
              CASE(int64_t);
#undef CASE
              default:
                LOG(FATAL) << "Unexpected " << DataTypeString(dtype);
            }
            std::swap(padded_samples[i][j], dst);
          }
        }
      }
    }

    // Validate that the samples can be merged: every padded_samples[i][j]
    // must have the same dtype and shape as padded_samples[0][j].
    for (int i = 1; i < num_samples; ++i) {
      if (padded_samples[i].size() != num_outs) {
        LOG(FATAL) << "Samples have different sizes: " << samples[i].size()
                   << " vs. " << num_outs;
      }
      for (int j = 0; j < num_outs; ++j) {
        if (padded_samples[i][j].dtype() != padded_samples[0][j].dtype()) {
          LOG(FATAL) << "Mismatched data types of samples (" << i << "/" << j
                     << "): " << padded_samples[i][j].dtype() << " vs. "
                     << padded_samples[0][j].dtype();
        }
        if (padded_samples[i][j].shape() != padded_samples[0][j].shape()) {
          LOG(FATAL) << "Mismatched shapes of samples (" << i << "/" << j
                     << "): " << padded_samples[i][j].shape().DebugString()
                     << " vs. " << padded_samples[0][j].shape().DebugString();
        }
      }
    }

    batch->clear();
    for (int i = 0; i < num_outs; ++i) {
      const Tensor& src = padded_samples[0][i];
      DataType dtype = src.dtype();
      switch (dtype) {
        case DT_FLOAT:
        case DT_UINT8:
        case DT_INT32:
        case DT_INT64:
        case DT_STRING:
        case DT_BFLOAT16:
        case DT_COMPLEX64:
        case DT_COMPLEX128:
        case DT_BOOL:
          break;
        default:
          LOG(FATAL) << DataTypeString(dtype) << " is not supported.";
      }
      // The merged tensor is one rank higher than each sample; its leading
      // dimension is num_samples.
      TensorShape shape = src.shape();
      shape.InsertDim(0, num_samples);
      if (num_samples == 1) {
        // Avoid memcpy if there is just one sample.
        Tensor reshaped(dtype);
        CHECK(reshaped.CopyFrom(src, shape));
        batch->push_back(reshaped);
      } else {
        batch->push_back(Tensor(dtype, shape));
      }
    }
    // If there is just one sample, 'batch' already has the copy.
    if (num_samples == 1) return Status::OK();

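    // Copy each sample into row j of the merged output tensors, sharding the
    // samples across the merger thread pool.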
    Sharder::Do(
        num_samples /* total */, 1000 /* cost_per_unit */,
        [&](int64_t start, int64_t limit) {
          for (int i = 0; i < num_outs; ++i) {
            DataType dtype = padded_samples[0][i].dtype();
            Tensor* merged = &(*batch)[i];
            for (int j = start; j < limit; ++j) {
              switch (dtype) {
#define CASE(T)                                                               \
  case DataTypeToEnum<T>::value:                                              \
    merged->flat_outer_dims<T>().chip<0>(j) = padded_samples[j][i].flat<T>(); \
    break
                CASE(float);
                CASE(int32);
                CASE(int64_t);
                CASE(tstring);
                CASE(uint8);
                CASE(bfloat16);
                CASE(complex64);
                CASE(complex128);
                CASE(bool);
#undef CASE
                default:
                  LOG(FATAL) << "Unexpected " << DataTypeString(dtype);
              }
            }
          }
        },
        merger_runner_, 1 + num_merger_threads_);
    return Status::OK();
  }
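
A concrete, purely illustrative trace: with dynamic_padding_dimensions_ =
{1}, dynamic_padding_constants_ = {0}, and three samples whose single float
output has shapes [4, 7], [4, 5], and [4, 9], the first two tensors are
zero-padded to [4, 9], validation passes, and the returned batch contains a
single tensor of shape [3, 4, 9].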