in fairseq/data/multilingual/sampled_multi_dataset.py [0:0]
def collater(self, samples, **extra_args):
    """Merge a list of ``(dataset_index, sample)`` pairs into a mini-batch.

    The merge strategy is selected by attributes read from ``self``:

    * ``self.collate_format == "ordered_dict"``: samples are grouped by
      dataset index, every non-empty group is collated by its own dataset,
      and the result is an ``OrderedDict`` mapping the dataset's key (from
      ``self.keys``) to that dataset's batch.
    * ``self.shared_collater`` is truthy: every sample is collated in one
      call to ``self.datasets[0].collater``.
    * otherwise: every dataset collates its own samples with a common
      ``pad_to_length``; the per-dataset batches are concatenated along
      dim 0 and reordered by descending source length.

    Args:
        samples: list of ``(dataset_index, sample)`` tuples. In the last
            mode each sample must provide ``sample["source"]`` and
            ``sample["target"]`` (the latter may be ``None``).
        **extra_args: only ``pad_to_length`` is consulted, and only in the
            last mode. It is an optional mapping of minimum padded lengths
            (keys ``"source"`` / ``"target"``); missing keys count as 0 and
            the mapping itself is never modified.

    Returns:
        ``None`` when ``samples`` is empty, otherwise the merged batch
        (an ``OrderedDict`` in the first mode, whatever
        ``self.datasets[0].collater`` returns in the second, a ``dict``
        in the third).
    """
    if len(samples) == 0:
        return None

    if self.collate_format == "ordered_dict":
        collect_samples = [[] for _ in self.datasets]
        for i, sample in samples:
            collect_samples[i].append(sample)
        batch = OrderedDict(
            (key, dataset.collater(group))
            for key, dataset, group in zip(
                self.keys, self.datasets, collect_samples
            )
            if len(group) > 0
        )
    elif self.shared_collater:
        batch = self.datasets[0].collater([s for _, s in samples])
    else:
        # Work on a private copy: the caller's mapping supplies minimum
        # lengths only.  Writing the observed maxima back into it would leak
        # this batch's lengths into later calls that reuse the same dict,
        # and a plain dict lacking "source"/"target" would raise KeyError.
        pad_to_length = defaultdict(int, extra_args.get("pad_to_length") or {})

        samples_dict = defaultdict(list)
        for ds_idx, s in samples:
            pad_to_length["source"] = max(
                pad_to_length["source"], s["source"].size(0)
            )
            if s["target"] is not None:
                pad_to_length["target"] = max(
                    pad_to_length["target"], s["target"].size(0)
                )
            samples_dict[ds_idx].append(s)

        # Every dataset pads to the same lengths so that the resulting
        # tensors can be concatenated along dim 0 below.
        batches = [
            dataset.collater(samples_dict[i], pad_to_length=pad_to_length)
            for i, dataset in enumerate(self.datasets)
            if i in samples_dict
        ]

        src_lengths = torch.cat(
            [b["net_input"]["src_lengths"] for b in batches], dim=0
        )
        src_lengths, sort_order = src_lengths.sort(descending=True)

        def straight_order(tensors):
            # Concatenate per-dataset tensors, then apply the global
            # descending-source-length ordering.
            return torch.cat(tensors, dim=0).index_select(0, sort_order)

        batch = {
            "id": straight_order([b["id"] for b in batches]),
            "nsentences": sum(b["nsentences"] for b in batches),
            "ntokens": sum(b["ntokens"] for b in batches),
            "net_input": {
                "src_tokens": straight_order(
                    [b["net_input"]["src_tokens"] for b in batches]
                ),
                "src_lengths": src_lengths,
            },
            "target": straight_order([b["target"] for b in batches])
            if batches[0]["target"] is not None
            else None,
        }
        # Optional fields: presence is decided by the first per-dataset batch.
        if "prev_output_tokens" in batches[0]["net_input"]:
            batch["net_input"]["prev_output_tokens"] = straight_order(
                [b["net_input"]["prev_output_tokens"] for b in batches]
            )
        if "src_lang_id" in batches[0]["net_input"]:
            batch["net_input"]["src_lang_id"] = straight_order(
                [b["net_input"]["src_lang_id"] for b in batches]
            )
        if "tgt_lang_id" in batches[0]:
            batch["tgt_lang_id"] = straight_order(
                [b["tgt_lang_id"] for b in batches]
            )
    return batch