# From optimum/graphcore/models/wav2vec2/ipu_gumbel_vector_quantizer.py
def forward(self, hidden_states, gumbel_temperature=2.0, mask_time_indices=None):
    """Quantize hidden states into codevectors via (Gumbel-)softmax over codebook groups.

    Args:
        hidden_states: float tensor of shape ``(batch, seq_len, hidden_size)``.
        gumbel_temperature: temperature for the Gumbel-softmax sampling used
            during training (higher = softer distribution).
        mask_time_indices: optional mask forwarded to ``self._compute_perplexity``
            to restrict the perplexity computation to selected time steps.

    Returns:
        Tuple ``(codevectors, code_perplexity, prob_perplexity)`` where
        ``codevectors`` has shape ``(batch, seq_len, codevector_dim)``.
    """
    batch_size, sequence_length, hidden_size = hidden_states.shape

    # Project to num_groups * num_vars logits, one row per (time step, group).
    hidden_states = self.weight_proj(hidden_states)
    hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

    # Hard (argmax) code assignment: used for the code perplexity and at eval time.
    codevector_idx = hidden_states.argmax(dim=-1)
    hard_probs = torch.nn.functional.one_hot(codevector_idx.long(), num_classes=self.num_vars).view(
        batch_size * sequence_length, self.num_groups, -1
    )
    code_perplexity = self._compute_perplexity(hard_probs.float(), mask_time_indices)

    # Soft (softmax) distribution: used for the differentiable perplexity term.
    soft_probs = torch.softmax(
        hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(),
        dim=-1,
    )
    prob_perplexity = self._compute_perplexity(soft_probs, mask_time_indices)

    if self.training:
        # sample code vector probs via gumbel in differentiateable way
        codevector_probs = _ipu_gumbel_softmax(hidden_states.float(), tau=gumbel_temperature, hard=True).type_as(
            hidden_states
        )
    else:
        codevector_probs = hard_probs.type_as(hidden_states)
    codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

    # Look up the selected codevectors with one batched matmul per group, then
    # concatenate the per-group vectors along the feature dimension.
    # NOTE(review): assumes self.codevectors has a leading singleton batch dim
    # of shape (1, num_groups * num_vars, codevector_dim // num_groups).
    codebook = self.codevectors[0, :, :]
    codebook = codebook.view(self.num_groups, self.num_vars, -1)
    codevectors = torch.bmm(codevector_probs.permute(1, 0, 2), codebook).permute(1, 0, 2)
    # Fix: the original performed this identical reshape twice in a row; the
    # duplicate was a no-op and has been removed.
    codevectors = codevectors.reshape(batch_size, sequence_length, -1)

    return codevectors, code_perplexity, prob_perplexity