in avhubert/hubert_dataset.py [0:0]
def collater(self, samples):
    """Collate raw dataset samples into one padded batch dict.

    Drops samples whose ``id`` is None; returns ``{}`` for an empty batch.
    The common frame length is taken from whichever modality (audio/video)
    is present, capped by ``self.max_sample_size`` (max of sizes when
    ``self.pad_audio``, min otherwise).  The crop offsets produced while
    collating audio are reused for video and for the labels so all streams
    stay aligned.  Target fields are laid out per ``self.single_target``
    (and ``self.is_s2s`` for seq2seq prev-token inputs).
    """
    samples = [s for s in samples if s["id"] is not None]
    if not samples:
        return {}

    audios = [s["audio_source"] for s in samples]
    videos = [s["video_source"] for s in samples]
    # A modality is treated as absent for the whole batch when the first
    # sample lacks it (the dataset yields homogeneous batches).
    if audios[0] is None:
        audios = None
    if videos[0] is None:
        videos = None

    # Frame counts come from whichever modality exists.
    sizes = [len(s) for s in (audios if audios is not None else videos)]
    if self.pad_audio:
        batch_size_frames = min(max(sizes), self.max_sample_size)
    else:
        batch_size_frames = min(min(sizes), self.max_sample_size)

    collated_audios = None
    collated_videos = None
    starts = None
    if audios is not None:
        collated_audios, padding_mask, starts = self.collater_audio(
            audios, batch_size_frames
        )
    if videos is not None:
        # Reuse the audio crop offsets (None when audio is absent) so the
        # two modalities are cut at identical positions.
        collated_videos, padding_mask, starts = self.collater_audio(
            videos, batch_size_frames, starts
        )

    labels_per_task = [
        [s["label_list"][i] for s in samples] for i in range(self.num_labels)
    ]
    targets_list, lengths_list, ntokens_list = self.collater_label(
        labels_per_task, batch_size_frames, starts
    )

    net_input = {
        "source": {"audio": collated_audios, "video": collated_videos},
        "padding_mask": padding_mask,
    }
    batch = {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "net_input": net_input,
        "utt_id": [s["fid"] for s in samples],
    }

    if self.single_target:
        batch["target_lengths"] = lengths_list[0]
        batch["ntokens"] = ntokens_list[0]
        if self.is_s2s:
            # Seq2seq: first element is the target, second the shifted
            # decoder input (prev_output_tokens).
            batch["target"] = targets_list[0][0]
            net_input["prev_output_tokens"] = targets_list[0][1]
        else:
            batch["target"] = targets_list[0]
    else:
        batch["target_lengths_list"] = lengths_list
        batch["ntokens_list"] = ntokens_list
        batch["target_list"] = targets_list
    return batch