def collater()

in fairseq/data/multilingual/sampled_multi_dataset.py


    def collater(self, samples, **extra_args):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
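        # Group samples by their originating sub-dataset and collate each group
        # separately, keyed by that dataset's key in an OrderedDict.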
        if self.collate_format == "ordered_dict":
            collect_samples = [[] for _ in range(len(self.datasets))]
            for (i, sample) in samples:
                collect_samples[i].append(sample)
            batch = OrderedDict(
                [
                    (self.keys[i], dataset.collater(collect_samples[i]))
                    for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
                    if len(collect_samples[i]) > 0
                ]
            )
        elif self.shared_collater:
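            # All sub-datasets share the same collater; delegate the whole
            # mini-batch to the first dataset's collater.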
            batch = self.datasets[0].collater([s for _, s in samples])
        else:
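            # Default: collate each sub-dataset separately with a common padded
            # length, then merge into one flat batch sorted by source length.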
            samples_dict = defaultdict(list)
            pad_to_length = (
                defaultdict(int)
                if "pad_to_length" not in extra_args
                else extra_args["pad_to_length"]
            )
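            # Track the longest source/target in the mini-batch so every
            # sub-dataset collater pads to the same length, and group samples
            # by sub-dataset index.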
            for ds_idx, s in samples:
                pad_to_length["source"] = max(
                    pad_to_length["source"], s["source"].size(0)
                )
                if s["target"] is not None:
                    pad_to_length["target"] = max(
                        pad_to_length["target"], s["target"].size(0)
                    )
                samples_dict[ds_idx].append(s)
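            # Collate each non-empty group of samples with the shared pad_to_length.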
            batches = [
                self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
                for i in range(len(self.datasets))
                if len(samples_dict[i]) > 0
            ]

            def straight_data(tensors):
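                # Concatenate per-dataset tensors along the batch (first) dimension.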
                batch = torch.cat(tensors, dim=0)
                return batch

            src_lengths = straight_data(
                [b["net_input"]["src_lengths"] for b in batches]
            )
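            # Sort the merged mini-batch by descending source length and keep
            # the permutation so every other field can be reordered to match.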
            src_lengths, sort_order = src_lengths.sort(descending=True)

            def straight_order(tensors):
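                # Concatenate, then reorder rows with the src_lengths sort order.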
                batch = straight_data(tensors)
                return batch.index_select(0, sort_order)

            batch = {
                "id": straight_order([b["id"] for b in batches]),
                "nsentences": sum(b["nsentences"] for b in batches),
                "ntokens": sum(b["ntokens"] for b in batches),
                "net_input": {
                    "src_tokens": straight_order(
                        [b["net_input"]["src_tokens"] for b in batches]
                    ),
                    "src_lengths": src_lengths,
                },
                "target": straight_order([b["target"] for b in batches])
                if batches[0]["target"] is not None
                else None,
            }
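            # Optional fields: copy them over only if the sub-batches provide them.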
            if "prev_output_tokens" in batches[0]["net_input"]:
                batch["net_input"]["prev_output_tokens"] = straight_order(
                    [b["net_input"]["prev_output_tokens"] for b in batches]
                )
            if "src_lang_id" in batches[0]["net_input"]:
                batch["net_input"]["src_lang_id"] = straight_order(
                    [b["net_input"]["src_lang_id"] for b in batches]
                )
            if "tgt_lang_id" in batches[0]:
                batch["tgt_lang_id"] = straight_order(
                    [b["tgt_lang_id"] for b in batches]
                )
        return batch
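
For reference, here is a minimal standalone sketch of the concatenate-and-reorder pattern used by the default branch above. The toy sub-batches stand in for real `dataset.collater` output (real sub-batches nest `src_lengths` under `net_input`); the field values are illustrative only, not taken from fairseq:

    import torch

    # Two toy sub-batches, as if produced by two sub-datasets' collaters.
    sub_batches = [
        {"id": torch.tensor([0, 1]), "src_lengths": torch.tensor([5, 3])},
        {"id": torch.tensor([2]), "src_lengths": torch.tensor([7])},
    ]

    # Concatenate along the batch dimension, then sort by descending source length.
    src_lengths = torch.cat([b["src_lengths"] for b in sub_batches], dim=0)
    src_lengths, sort_order = src_lengths.sort(descending=True)

    # Apply the same permutation to every other field so rows stay aligned.
    ids = torch.cat([b["id"] for b in sub_batches], dim=0).index_select(0, sort_order)

    print(src_lengths.tolist())  # [7, 5, 3]
    print(ids.tolist())          # [2, 0, 1]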