in avhubert/hubert.py [0:0]
def apply_input_mask(self, x, padding_mask, target_list):
    """Mask random time spans of the model input (audio or video).

    Span positions come from ``compute_mask_indices`` over the ``(B, T)``
    time grid. How masked frames are filled depends on the input:

    * ``B == 1``: masked frames are set to 0.
    * audio input (3-D ``x``): masked frames are set to ``self.mask_emb``.
    * video input, ``self.selection_type == 'same_other_seq'``: masked frames
      are copied from the same time positions of a different sample, chosen
      by a random cyclic shift of the batch.
    * video input, ``self.selection_type == 'same_seq'``: each masked span is
      overwritten with an equally long span taken from elsewhere in the same
      sample.

    Args:
        x: input tensor. A 3-D tensor is treated as audio (indexed as
            ``[B, C, T]``); otherwise as video (``[B, C, T, H, W]`` per the
            layout comment below -- TODO confirm for all callers).
        padding_mask: forwarded unchanged to ``compute_mask_indices``.
        target_list: unused here; kept for interface compatibility.

    Returns:
        ``(x, mask_indices)``: the masked input in its original layout, and
        the mask over the ``(B, T)`` grid as a tensor on ``x.device``
        (presumably boolean -- it is used directly as an index mask), or
        ``None`` when the modality's mask probability is not positive.

    Raises:
        ValueError: for video input with ``B > 1`` when
            ``self.selection_type`` is neither ``'same_other_seq'`` nor
            ``'same_seq'`` (otherwise frames would be reported as masked
            without actually being masked).
    """
    B, C, T = x.shape[:3]
    is_audio = len(x.shape) == 3
    if is_audio:
        mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
    else:
        mask_prob, mask_length = self.mask_prob_image, self.mask_length_image

    if mask_prob > 0:
        mask_indices, starts, ends, batch_indexes = compute_mask_indices(
            (B, T),
            padding_mask,
            mask_prob,
            mask_length,
            self.mask_selection,
            self.mask_other,
            min_masks=2,
            no_overlap=self.no_mask_overlap,
            min_space=self.mask_min_space,
        )
        mask_indices = torch.from_numpy(mask_indices).to(x.device)
        # Move time next to batch so the (B, T) mask selects whole frames.
        x = x.transpose(1, 2).contiguous()  # [B, T, C, H, W]
        if B == 1:
            x[mask_indices] = 0
        elif is_audio:
            x[mask_indices] = self.mask_emb
        elif self.selection_type == 'same_other_seq':
            # A cyclic shift by 1..B-1 pairs every sample with a different one
            # (B > 1 is guaranteed on this branch).
            perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
            x_perm = x[perm]
            x[mask_indices] = x_perm[mask_indices]
        elif self.selection_type == 'same_seq':
            # NOTE(review): the flattened source indices must follow the same
            # order in which x[mask_indices] enumerates masked frames
            # (batch-major, then time); this assumes compute_mask_indices
            # returns spans in that order -- TODO confirm.
            src_batch, src_time = [], []
            for batch_index, start, end in zip(batch_indexes, starts, ends):
                length = end - start
                # Candidate source starts whose span [s, s + length) does not
                # overlap the masked span [start, end).
                # NOTE(review): candidates cover all T frames, including any
                # padded frames of shorter samples.
                candidates = np.setdiff1d(
                    np.arange(T), np.arange(max(0, start - length), end)
                )
                if len(candidates) > 0:
                    # Take a scalar: passing a size-1 array to np.arange relies
                    # on deprecated array-to-scalar conversion.
                    src_start = int(np.random.choice(candidates, size=1)[0])
                else:
                    src_start = 0  # no non-overlapping span exists; fall back
                # Spans running past the end repeat the last frame.
                src_time.append(
                    np.arange(src_start, src_start + length).clip(max=T - 1)
                )
                src_batch.append(np.full(length, batch_index, dtype=np.int64))
            if src_batch:  # np.concatenate raises on an empty sequence
                src_batch = np.concatenate(src_batch)
                src_time = np.concatenate(src_time)
                x[mask_indices] = x[src_batch, src_time]
        else:
            raise ValueError(
                f"Unknown selection_type {self.selection_type!r} for input masking"
            )
        x = x.transpose(1, 2).contiguous()
    else:
        mask_indices = None

    if self.mask_channel_prob > 0:
        logger.info("No mask channel prob for input masking")
    return x, mask_indices