in lingvo/core/ops/generic_input_op_kernels.cc [189:329]
Status Merge(int64_t bucket_size, const std::vector<TensorVec>& samples,
TensorVec* batch) override {
CHECK(!samples.empty());
const auto num_samples = samples.size();
const auto num_outs = samples[0].size();
std::vector<TensorVec> padded_samples(samples.begin(), samples.end());
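// If dynamic padding is configured, pad the selected outputs along their
// dynamic dimension so that every sample ends up with the same shape.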
if (!dynamic_padding_dimensions_.empty()) {
CHECK(dynamic_padding_dimensions_.size() == num_outs);
CHECK(dynamic_padding_constants_.size() == num_outs);
for (int j = 0; j < num_outs; ++j) {
const int pad_dim = dynamic_padding_dimensions_[j];
if (pad_dim == -1) {
continue;
}
const int pad_value = dynamic_padding_constants_[j];
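// Find the longest extent of the padded dimension across all samples.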
int64_t max_length = 0;
for (int i = 0; i < samples.size(); ++i) {
max_length = std::max(max_length, samples[i][j].dim_size(pad_dim));
}
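// Pad every sample whose padded dimension is shorter than max_length.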
for (int i = 0; i < samples.size(); ++i) {
const auto& src = samples[i][j];
if (src.dims() > 0 && src.dim_size(pad_dim) < max_length) {
DataType dtype = src.dtype();
TensorShape dst_shape(src.shape());
dst_shape.set_dim(pad_dim, max_length);
Tensor dst(dtype, dst_shape);
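// Fill the destination with the padding constant, then copy the source
// into the leading slice along the padded dimension.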
switch (dtype) {
#define CASE(T) \
case DataTypeToEnum<T>::value: \
dst.flat<T>().setConstant(pad_value); \
if (src.NumElements() > 0) { \
auto src_t = src.flat_inner_outer_dims<T, 2>(pad_dim - 1); \
auto dst_t = dst.flat_inner_outer_dims<T, 2>(pad_dim - 1); \
typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes; \
dst_t.slice(DSizes(), DSizes(src_t.dimensions())) = src_t; \
} \
break
CASE(float);
CASE(int32);
CASE(int64_t);
#undef CASE
default:
LOG(FATAL) << "Unexpected " << DataTypeString(dtype);
}
std::swap(padded_samples[i][j], dst);
}
}
}
}
// Validate that samples can be merged: for every output j, samples[:][j]
// must have the same type and shape.
for (int i = 1; i < padded_samples.size(); ++i) {
if (padded_samples[i].size() != num_outs) {
LOG(FATAL) << "Samples have different sizes: " << samples[i].size()
<< " vs. " << num_outs;
}
for (int j = 0; j < num_outs; ++j) {
if (padded_samples[i][j].dtype() != padded_samples[0][j].dtype()) {
LOG(FATAL) << "Mismatch data types of samples (" << i << "/" << j
<< "): " << samples[i][j].dtype() << " vs. "
<< samples[0][j].dtype();
}
if (padded_samples[i][j].shape() != padded_samples[0][j].shape()) {
LOG(FATAL) << "Mismatch shape of samples (" << i << "/" << j
<< "): " << samples[i][j].shape().DebugString() << " vs. "
<< samples[0][j].shape().DebugString();
}
}
}
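// Verify that each output dtype is supported and allocate the merged
// output tensors (the single sample is aliased when num_samples == 1).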
batch->clear();
for (int i = 0; i < num_outs; ++i) {
const Tensor& src = padded_samples[0][i];
DataType dtype = src.dtype();
switch (dtype) {
case DT_FLOAT:
case DT_UINT8:
case DT_INT32:
case DT_INT64:
case DT_STRING:
case DT_BFLOAT16:
case DT_COMPLEX64:
case DT_COMPLEX128:
case DT_BOOL:
break;
default:
LOG(FATAL) << DataTypeString(dtype) << " is not supported.";
}
// The merged tensor is one rank higher than each sample; its leading
// dimension is num_samples.
TensorShape shape = src.shape();
shape.InsertDim(0, num_samples);
if (num_samples == 1) {
// Avoid memcpy if there is just one sample.
Tensor reshaped(dtype);
CHECK(reshaped.CopyFrom(src, shape));
batch->push_back(reshaped);
} else {
batch->push_back(Tensor(dtype, shape));
}
}
// If there is just one sample, 'batch' already has the copy.
if (num_samples == 1) return Status::OK();
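// Copy each sample into its row of the merged tensors; the work is
// sharded over samples across the merger thread pool.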
Sharder::Do(
num_samples /* total */, 1000 /* cost_per_unit */,
[&](int64_t start, int64_t limit) {
for (int i = 0; i < num_outs; ++i) {
DataType dtype = padded_samples[0][i].dtype();
Tensor* merged = &(*batch)[i];
for (int j = start; j < limit; ++j) {
switch (dtype) {
#define CASE(T) \
case DataTypeToEnum<T>::value: \
merged->flat_outer_dims<T>().chip<0>(j) = padded_samples[j][i].flat<T>(); \
break
CASE(float);
CASE(int32);
CASE(int64_t);
CASE(tstring);
CASE(uint8);
CASE(bfloat16);
CASE(complex64);
CASE(complex128);
CASE(bool);
#undef CASE
default:
LOG(FATAL) << "Unexpected " << DataTypeString(dtype);
}
}
}
},
merger_runner_, 1 + num_merger_threads_);
return Status::OK();
}
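
// A minimal standalone sketch of the pad-and-stack behavior that Merge()
// implements above, using plain std::vector<int32_t> in place of
// tensorflow::Tensor for a single rank-1 int32 output. The names PadAndStack
// and pad_value are illustrative only and are not part of the Lingvo API.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Pads each sample to the length of the longest one with `pad_value`, then
// stacks them into a dense [num_samples, max_length] layout, mirroring the
// dynamic-padding and merge steps of Merge().
std::vector<std::vector<int32_t>> PadAndStack(
    const std::vector<std::vector<int32_t>>& samples, int32_t pad_value) {
  std::size_t max_length = 0;
  for (const auto& s : samples) max_length = std::max(max_length, s.size());
  std::vector<std::vector<int32_t>> batch;
  batch.reserve(samples.size());
  for (const auto& s : samples) {
    std::vector<int32_t> row(max_length, pad_value);  // fill with pad_value
    std::copy(s.begin(), s.end(), row.begin());       // copy the real data
    batch.push_back(std::move(row));
  }
  return batch;
}

int main() {
  auto batch = PadAndStack({{1, 2, 3}, {4}, {5, 6}}, /*pad_value=*/0);
  for (const auto& row : batch) {
    for (int32_t v : row) std::cout << v << ' ';
    std::cout << '\n';
  }
  // Prints:
  // 1 2 3
  // 4 0 0
  // 5 6 0
  return 0;
}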