def forward()

in muse/modeling_movq.py
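
Vector-quantization forward pass of the MoVQ quantizer: each spatial latent vector is snapped to its nearest codebook embedding (one-hot selection followed by a matmul against the embedding table), the VQ-VAE codebook/commitment loss is optionally computed, and the straight-through estimator keeps gradients flowing through the discrete bottleneck.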


    def forward(self, hidden_states, return_loss=False):
        # reshape z -> (batch, height, width, channel); compute_distances
        # flattens it to (batch * height * width, channel) internally
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()

        # index of the nearest codebook entry for each spatial position
        distances = self.compute_distances(hidden_states)
        min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # one-hot encodings on the same device/dtype as the latents
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

        # compute the VQ-VAE loss for the codebook
        loss = None
        if return_loss:
            if not self.legacy:
                # beta weights the commitment term ||sg[z_q] - z||^2
                loss = self.beta * torch.mean((z_q.detach() - hidden_states) ** 2) + torch.mean(
                    (z_q - hidden_states.detach()) ** 2
                )
            else:
                # legacy weighting: beta on the codebook term ||z_q - sg[z]||^2 instead
                loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.beta * torch.mean(
                    (z_q - hidden_states.detach()) ** 2
                )

        # preserve gradients: the straight-through estimator copies gradients
        # from z_q to the encoder output across the discrete bottleneck
        z_q = hidden_states + (z_q - hidden_states).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, min_encoding_indices, loss
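
The distance computation is delegated to self.compute_distances, which is defined elsewhere in modeling_movq.py. As a rough sketch of what such a method conventionally computes in VQ-VAE-style quantizers (the free-standing signature and the explicit codebook argument here are illustrative assumptions, not the module's actual API):

    import torch

    def compute_distances(hidden_states, codebook):
        # Illustrative sketch (hypothetical signature): squared L2 distance between
        # every flattened latent vector and every codebook entry, expanded as
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e^T so z - e is never materialised.
        # hidden_states: (batch, height, width, channel); codebook: (num_embeddings, channel)
        z_flat = hidden_states.reshape(-1, codebook.shape[1])
        return (
            z_flat.pow(2).sum(dim=1, keepdim=True)  # (N, 1)
            + codebook.pow(2).sum(dim=1)            # (K,), broadcast to (N, K)
            - 2.0 * z_flat @ codebook.t()           # (N, K)
        )

torch.argmin over dim=1 of the resulting (N, num_embeddings) matrix then yields exactly the per-position indices that the forward pass above consumes.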