def __call__()

in ultravox/model/ultravox_processing.py [0:0]


    def __call__(self, features, *args, **kwargs):
        """Collate a list of per-sample feature dicts into a padded batch.

        The audio fields (``audio_values``, ``audio_lens``, ``audio_token_len``,
        ``audio_token_start_idx``) are popped from every feature dict before
        the parent collator runs, and are stacked separately afterwards.
        Each feature contributes a (possibly empty) sequence of entries for
        each of these keys, so samples without audio contribute nothing.

        When ``self.include_alt_fields`` is set, ``alt_input_ids``,
        ``alt_attention_mask`` and ``alt_labels`` are collated in a second
        pass through the parent collator and stored on the batch under their
        ``alt_*`` names.

        Args:
            features: list of per-sample dicts. NOTE: they are mutated in
                place (the audio and ``alt_*`` keys are popped).
            *args, **kwargs: forwarded unchanged to the parent collator.

        Returns:
            The batch produced by the parent collator, extended with the
            stacked audio fields. Each audio field is only added when at least
            one sample supplied entries for it, so text-only batches do not
            crash in ``torch.stack``.
        """
        audio_values = [x for f in features for x in f.pop("audio_values", [])]
        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
        audio_token_start_idx = [
            x for f in features for x in f.pop("audio_token_start_idx", [])
        ]

        if self.include_alt_fields:
            # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
            alt_features = [
                {
                    "input_ids": f.pop("alt_input_ids"),
                    "attention_mask": f.pop("alt_attention_mask"),
                    "labels": f.pop("alt_labels"),
                }
                for f in features
            ]

        batch = super().__call__(features, *args, **kwargs)
        if self.include_alt_fields:
            alt_batch = super().__call__(alt_features, *args, **kwargs)
            batch["alt_input_ids"] = alt_batch["input_ids"]
            batch["alt_attention_mask"] = alt_batch["attention_mask"]
            batch["alt_labels"] = alt_batch["labels"]

        # torch.stack raises on an empty list, so a batch in which no sample
        # carries audio must skip these fields instead of crashing.
        for key, values in (
            ("audio_token_start_idx", audio_token_start_idx),
            ("audio_lens", audio_lens),
            ("audio_token_len", audio_token_len),
        ):
            if values:
                batch[key] = torch.stack(values)

        if audio_values:
            # Pad the last dimension of all audio_values to the same length,
            # with 0s on the right, so they can be stacked.
            max_len = max(x.shape[-1] for x in audio_values)
            batch["audio_values"] = torch.stack(
                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
            )
            if self.tokenizer.padding_side == "left":
                # With left padding, each sample's tokens move right by
                # (padded length - original length); shift the audio token
                # start indices by the same amount.
                input_ids_lens = torch.LongTensor(
                    [f["input_ids"].shape[-1] for f in features]
                )
                displacement = batch["input_ids"].shape[-1] - input_ids_lens
                # A sample may contribute several audio entries, so its
                # displacement is repeated once per entry. Assumes
                # batch["audio_batch_size"] holds the per-sample entry count
                # with a trailing dim of 1 (hence squeeze) -- TODO confirm.
                displacement = displacement.repeat_interleave(
                    batch["audio_batch_size"].squeeze(-1)
                )
                batch["audio_token_start_idx"] += displacement.to(
                    batch["audio_token_start_idx"].device
                )
        return batch