def forward()

in optimum/graphcore/models/wav2vec2/ipu_gumbel_vector_quantizer.py [0:0]


    def forward(self, hidden_states, gumbel_temperature=2.0, mask_time_indices=None):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        # hard (argmax) code assignments, used to measure codebook usage via perplexity
        codevector_idx = hidden_states.argmax(dim=-1)
        hard_probs = torch.nn.functional.one_hot(codevector_idx.long(), num_classes=self.num_vars).view(
            batch_size * sequence_length, self.num_groups, -1
        )
        code_perplexity = self._compute_perplexity(hard_probs.float(), mask_time_indices)

        # soft (softmax) distribution over codevectors per group
        soft_probs = torch.softmax(
            hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(),
            dim=-1,
        )
        prob_perplexity = self._compute_perplexity(soft_probs, mask_time_indices)

        if self.training:
            # sample codevector probs via gumbel-softmax in a differentiable way
            codevector_probs = _ipu_gumbel_softmax(hidden_states.float(), tau=gumbel_temperature, hard=True).type_as(
                hidden_states
            )
        else:
            codevector_probs = hard_probs.type_as(hidden_states)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
        # look up codevectors per group with a batched matmul against the codebook
        codebook = self.codevectors[0, :, :]
        codebook = codebook.view(self.num_groups, self.num_vars, -1)
        codevectors = torch.bmm(codevector_probs.permute(1, 0, 2), codebook).permute(1, 0, 2)
        codevectors = codevectors.reshape(batch_size, sequence_length, -1)

        return codevectors, code_perplexity, prob_perplexity
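
The training path calls `_ipu_gumbel_softmax`, whose definition is not included in this excerpt. For context only, below is a minimal sketch of a standard Gumbel-softmax with straight-through (`hard=True`) sampling, mirroring the behaviour of `torch.nn.functional.gumbel_softmax`; the name `gumbel_softmax_sketch` and its details are illustrative assumptions and may differ from the actual IPU-specific implementation.

import torch


def gumbel_softmax_sketch(logits, tau=1.0, hard=True, dim=-1):
    # Illustrative sketch only; not the optimum-graphcore implementation.
    # Sample Gumbel(0, 1) noise: g = -log(-log(u)) with u ~ Uniform(0, 1).
    uniform = torch.rand_like(logits).clamp_min(1e-10)
    gumbels = -torch.log((-torch.log(uniform)).clamp_min(1e-10))
    # Soft (relaxed) one-hot sample; lower tau pushes it closer to a hard one-hot.
    y_soft = torch.softmax((logits + gumbels) / tau, dim=dim)
    if not hard:
        return y_soft
    # Straight-through estimator: the forward value is the hard one-hot,
    # while gradients flow through the soft sample.
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft

In this forward pass the function is applied to logits of shape (batch_size * sequence_length * num_groups, num_vars), so each group independently yields an (approximately) one-hot selection over its num_vars codevectors.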