def __call__()

in modelling/src/neuraldb/dataset/data_collator_seq2seq.py
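
Note: the method below assumes `torch`, `numpy` (imported as `np`) and `typing.Iterable` are
imported at module level, and that the instance attributes it reads (`tokenizer`, `model`,
`padding`, `max_length`, `pad_to_multiple_of`, `label_pad_token_id`) are set elsewhere on the
class, typically in `__init__`.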


    def __call__(self, features: Iterable[dict]):
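        # Collect per-example metadata; records without it contribute an empty dict.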
        metadata = [record.get("metadata", {}) for record in features]

        labels = (
            [feature["labels"] for feature in features]
            if "labels" in features[0].keys()
            else None
        )
        # We have to pad the labels before calling `tokenizer.pad`, as that method does not pad them
        # and needs them all to be the same length before it can return tensors.
        if labels is not None:
            max_label_length = max(len(lab) for lab in labels)
            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (
                    max_label_length - len(feature["labels"])
                )
                feature["labels"] = (
                    feature["labels"] + remainder
                    if padding_side == "right"
                    else remainder + feature["labels"]
                )

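        # Standard single-sequence path: features already carry flat input_ids, so the tokenizer
        # pads everything except the metadata and the global attention mask.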
        if "input_ids" in features[0]:
            master_features = self.tokenizer.pad(
                [
                    {
                        k: v
                        for k, v in feature.items()
                        if k not in {"metadata", "global_attention_mask"}
                    }
                    for feature in features
                ],
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )

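        # Multi-context path: each example carries a list of encoded contexts. Flatten the
        # (example, context) pairs into per-context "virtual" features so the tokenizer can pad
        # them as one batch, and remember how many contexts each example contributed.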
        if "context_ids" in features[0]:
            virtual_features = []
            lengths = []

            for feature in features:
                assert len(feature["context_ids"]) == len(feature["context_mask"])
                virtual_features.extend(
                    {
                        "input_ids": context,
                        "attention_mask": attention,
                        "labels": feature["labels"],
                    }
                    for context, attention in zip(
                        feature["context_ids"], feature["context_mask"]
                    )
                )
                lengths.append(len(feature["context_ids"]))

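            # Pad the flattened per-context features as one batch.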
            master_features = self.tokenizer.pad(
                virtual_features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )

            # master_features["lengths"] = lengths
            master_features["context_ids"] = []
            master_features["context_mask"] = []
            master_features["labels"] = torch.stack(
                [master_features["labels"][i - 1] for i in np.cumsum(lengths)], dim=0
            )

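            # Split the flat padded batch back into one (num_contexts, seq_len) block per example.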
            previous = 0
            for length in lengths:
                end = previous + length

                master_features["context_ids"].append(
                    master_features["input_ids"][previous:end]
                )  # noqa: E501
                master_features["context_mask"].append(
                    master_features["attention_mask"][previous:end]
                )  # noqa: E501
                previous += length

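            # input_ids becomes a placeholder so downstream code still sees a batch dimension;
            # the real encoder inputs live in context_ids / context_mask.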
            master_features["input_ids"] = torch.zeros((len(lengths), 1))
            del master_features["attention_mask"]

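        # Pad the global attention mask (used by Longformer/LED-style models) alongside the
        # input_ids so it ends up with the same padded length; only its padded mask is kept.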
        if "global_attention_mask" in features[0]:
            additional_features = self.tokenizer.pad(
                [
                    {
                        "input_ids": feature["input_ids"],
                        "attention_mask": feature["global_attention_mask"],
                    }
                    for feature in features
                ],
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )
            master_features["global_attention_mask"] = additional_features[
                "attention_mask"
            ]

        # prepare decoder_input_ids
        if (
            "labels" in master_features
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
                labels=master_features["labels"]
            )
            master_features["decoder_input_ids"] = decoder_input_ids

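        # Only attach metadata when at least one example actually provided some.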
        if any(metadata):
            master_features["metadata"] = metadata

        return master_features
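
For orientation, a minimal hypothetical sketch of what the multi-context branch consumes and
returns; the token ids are made up, `collator` stands for an instance of this class with the
attributes above already configured, and the padded length L depends on the padding settings.

features = [
    {
        # two retrieved contexts for the first example (illustrative token ids)
        "context_ids": [[0, 713, 16, 2], [0, 3645, 2]],
        "context_mask": [[1, 1, 1, 1], [1, 1, 1]],
        "labels": [0, 10932, 2],
        "metadata": {"db_id": "example"},
    },
    {
        # a single context for the second example
        "context_ids": [[0, 100, 2]],
        "context_mask": [[1, 1, 1]],
        "labels": [0, 5629, 2, 1],
        "metadata": {},
    },
]

batch = collator(features)
# batch["context_ids"]  -> list of 2 tensors with shapes (2, L) and (1, L)
# batch["context_mask"] -> matching per-context attention masks
# batch["labels"]       -> tensor of shape (2, 4); shorter labels padded with label_pad_token_id
# batch["metadata"]     -> the original metadata dicts
# batch["input_ids"]    -> a (2, 1) zero placeholder; the real inputs are in context_ids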