in ultravox/model/ultravox_processing.py
# Requires: import torch; import torch.nn.functional as F
# This __call__ belongs to a data collator that subclasses
# transformers.DataCollatorForSeq2Seq, so super().__call__ performs the
# usual text padding and collation.
def __call__(self, features, *args, **kwargs):
    # Pop the audio fields out of each feature so the base collator, which
    # only understands the text fields, never sees them. A feature may hold
    # multiple audio segments, hence the flattening comprehensions.
    audio_values = [x for f in features for x in f.pop("audio_values", [])]
    audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
    audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
    audio_token_start_idx = [
        x for f in features for x in f.pop("audio_token_start_idx", [])
    ]
    if self.include_alt_fields:
        # These field names are hard-coded in the transformers data collator,
        # so they need special handling before calling the super method.
        alt_features = [
            {
                "input_ids": f.pop("alt_input_ids"),
                "attention_mask": f.pop("alt_attention_mask"),
                "labels": f.pop("alt_labels"),
            }
            for f in features
        ]
    batch = super().__call__(features, *args, **kwargs)
    if self.include_alt_fields:
        # Collate the alternate fields with the same base collator, then
        # merge them back into the batch under their "alt_" names.
        alt_batch = super().__call__(alt_features, *args, **kwargs)
        batch["alt_input_ids"] = alt_batch["input_ids"]
        batch["alt_attention_mask"] = alt_batch["attention_mask"]
        batch["alt_labels"] = alt_batch["labels"]

    batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
    batch["audio_lens"] = torch.stack(audio_lens)
    batch["audio_token_len"] = torch.stack(audio_token_len)

    # Pad the last dimension of all audio_values to the same length, with 0s
    # on the right, so the segments can be stacked into one batch tensor.
    if audio_values:
        max_len = max(x.shape[-1] for x in audio_values)
        batch["audio_values"] = torch.stack(
            [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
        )
    if self.tokenizer.padding_side == "left":
        # With left padding, every sequence is shifted right by the amount of
        # padding prepended to it, so the audio token start indices must be
        # shifted by the same displacement. repeat_interleave expands the
        # per-sequence displacement to one entry per audio segment, matching
        # the flattened audio_token_start_idx.
        input_ids_lens = torch.LongTensor(
            [f["input_ids"].shape[-1] for f in features]
        )
        displacement = batch["input_ids"].shape[-1] - input_ids_lens
        displacement = displacement.repeat_interleave(
            batch["audio_batch_size"].squeeze(-1)
        )
        batch["audio_token_start_idx"] += displacement.to(
            batch["audio_token_start_idx"].device
        )
    return batch
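
The audio padding step may be easier to follow in isolation. Below is a minimal, self-contained sketch of the same right-padding-and-stacking idiom; the shapes are made up for illustration and are not taken from the file above:

    import torch
    import torch.nn.functional as F

    # Two toy audio feature segments with different lengths on the last
    # (time) dimension; 80 is an illustrative feature dimension.
    audio_values = [torch.randn(80, 5), torch.randn(80, 8)]

    # Right-pad the last dimension with zeros so the segments stack cleanly.
    max_len = max(x.shape[-1] for x in audio_values)
    stacked = torch.stack(
        [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
    )
    print(stacked.shape)  # torch.Size([2, 80, 8])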
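
Likewise, the left-padding displacement is clearer with concrete numbers. A minimal sketch, with all values invented for illustration: a batch of two sequences padded to length 10, where sequence 0 is 7 tokens long with 2 audio segments and sequence 1 is 10 tokens long with 1 segment:

    import torch

    padded_len = 10
    input_ids_lens = torch.LongTensor([7, 10])
    audio_batch_size = torch.LongTensor([2, 1])          # segments per sequence
    audio_token_start_idx = torch.LongTensor([1, 4, 2])  # pre-padding positions

    # Left padding shifts every token right by (padded_len - original_len).
    displacement = padded_len - input_ids_lens           # tensor([3, 0])
    # Expand to one displacement per audio segment, matching the flattened
    # start-index list: segment order is (seq0, seq0, seq1).
    displacement = displacement.repeat_interleave(audio_batch_size)
    print(audio_token_start_idx + displacement)          # tensor([4, 7, 2])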