in muse/modeling_movq.py
def forward(self, hidden_states, return_loss=False):
    # reshape z -> (batch, height, width, channel); flattening happens
    # inside compute_distances
    hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
    distances = self.compute_distances(hidden_states)
    min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

    # one-hot encode the index of the nearest codebook entry per position
    min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
    min_encodings.scatter_(1, min_encoding_indices, 1)

    # get quantized latent vectors
    z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)

    # reshape to (batch, num_tokens)
    min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

    # codebook and commitment losses; beta weights the commitment term
    # (the legacy flag swaps which of the two terms is scaled)
    loss = None
    if return_loss:
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - hidden_states) ** 2) + torch.mean(
                (z_q - hidden_states.detach()) ** 2
            )
        else:
            loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.beta * torch.mean(
                (z_q - hidden_states.detach()) ** 2
            )

    # preserve gradients through quantization (straight-through estimator)
    z_q = hidden_states + (z_q - hidden_states).detach()

    # reshape back to match original input shape
    z_q = z_q.permute(0, 3, 1, 2).contiguous()
    return z_q, min_encoding_indices, loss
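
# NOTE: compute_distances() is called above but not shown in this excerpt.
# Below is a minimal sketch of the conventional VQ nearest-neighbour lookup,
# expanding ||z - e||^2 = ||z||^2 + ||e||^2 - 2<z, e>; this body is an
# assumption about the helper, not the repository's actual implementation.
import torch

def compute_distances(self, hidden_states):
    # flatten (batch, height, width, channel) -> (num_vectors, embedding_dim)
    flat = hidden_states.reshape(-1, self.embedding.weight.shape[1])
    emb = self.embedding.weight  # (num_embeddings, embedding_dim)
    # squared L2 distance from every latent vector to every codebook entry
    return (
        torch.sum(flat**2, dim=1, keepdim=True)
        + torch.sum(emb**2, dim=1)
        - 2 * torch.matmul(flat, emb.t())
    )  # shape: (num_vectors, num_embeddings)

# Under this reading, `distances` has one row per spatial position, so the
# argmin over dim 1 picks the nearest codebook entry for each position, and
# the later reshape(hidden_states.shape[0], -1) yields indices of shape
# (batch, height * width), i.e. one token index per spatial position.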
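
# The line `z_q = hidden_states + (z_q - hidden_states).detach()` above is the
# straight-through estimator: the forward value equals the quantized tensor,
# while gradients flow to the encoder as if quantization were the identity.
# Minimal self-contained demonstration, with torch.round standing in for the
# (equally non-differentiable) codebook lookup:
import torch

z = torch.randn(3, requires_grad=True)
z_q = torch.round(z)           # quantization: its own gradient is zero everywhere
out = z + (z_q - z).detach()   # forward value is z_q, backward acts as identity
out.sum().backward()
print(z.grad)                  # tensor([1., 1., 1.]) -- gradient passed straight through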