in janus/models/vq_model.py [0:0]
def get_codebook_entry(self, indices, shape=None, channel_first=True):
    """Fetch codebook vectors for ``indices`` and optionally reshape them.

    Args:
        indices: Index used to select rows of the codebook weight matrix
            (presumably a flat integer tensor of length b*h*w -- confirm
            against callers).
        shape: Optional target shape. Read as
            (batch, channel, height, width) when ``channel_first`` is true,
            otherwise as (batch, height, width, channel). When ``None`` the
            raw lookup result is returned as-is.
        channel_first: Selects which of the two layouts ``shape`` describes
            and, therefore, the layout of the returned tensor.

    Returns:
        The selected codebook vectors. They are taken from the
        L2-normalised codebook (normalised along the last axis) when
        ``self.l2_norm`` is truthy, otherwise from the raw weights.
    """
    if self.l2_norm:
        # Unit-normalise every code vector before the lookup.
        codebook = F.normalize(self.embedding.weight, p=2, dim=-1)
    else:
        codebook = self.embedding.weight

    quantized = codebook[indices]  # (b*h*w, c) for flat indices

    # No target shape requested: hand back the lookup result untouched.
    if shape is None:
        return quantized

    # Channel-last callers get a plain view with the requested shape.
    if not channel_first:
        return quantized.view(shape)

    # The lookup yields channel-last data, so first lay it out as
    # (b, h, w, c) and then move the channel axis to position 1.
    batch, channels = shape[0], shape[1]
    height, width = shape[2], shape[3]
    channels_last = quantized.reshape(batch, height, width, channels)
    return channels_last.permute(0, 3, 1, 2).contiguous()