in lerobot/common/policies/vqbet/vqbet_utils.py
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
    """Quantize `x` against the codebook; return the quantized features, code indices and (negated) distances."""
    needs_codebook_dim = x.ndim < 4
    sample_codebook_temp = (
        sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
    )

    x = x.float()

    # Add a leading codebook-head dimension if the input does not have one.
    if needs_codebook_dim:
        x = rearrange(x, "... -> 1 ...")

    # Flatten all batch/sequence dims: (h, b, n, d) -> (h, b * n, d).
    flatten, ps = pack_one(x, "h * d")

    if mask is not None:
        # Broadcast the (b, n) mask so it lines up with the flattened tokens of every codebook head.
        mask = repeat(
            mask,
            "b n -> c (b h n)",
            c=flatten.shape[0],
            h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
        )

    # Lazily initialize the codebook from the first batch if it has not been initialized yet.
    self.init_embed_(flatten, mask=mask)

    if self.affine_param:
        self.update_affine(flatten, self.embed, mask=mask)

    embed = self.embed if self.learnable_codebook else self.embed.detach()

    if self.affine_param:
        # Re-standardize the codebook so its statistics match the current batch statistics.
        codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
        batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
        embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

    # Negated Euclidean distances, so the closest code has the highest score.
    dist = -cdist(flatten, embed)

    # Pick one code per token (argmax, or Gumbel sampling when a temperature is set).
    embed_ind, embed_onehot = self.gumbel_sample(
        dist, dim=-1, temperature=sample_codebook_temp, training=self.training
    )

    embed_ind = unpack_one(embed_ind, ps, "h *")

    if self.training:
        # Build the output from the one-hot assignments (keeps any straight-through gradients).
        unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
        quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
    else:
        # At inference a plain embedding lookup is enough.
        quantize = batched_embedding(embed_ind, embed)

    if self.training and self.ema_update and not freeze_codebook:
        if self.affine_param:
            # Map the batch back into codebook statistics before accumulating the EMA buffers.
            flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

        if mask is not None:
            # Masked-out tokens must not contribute to the codebook update.
            embed_onehot[~mask] = 0.0

        # EMA update of per-code usage counts: N <- decay * N + (1 - decay) * sum(onehot).
        cluster_size = embed_onehot.sum(dim=1)
        self.all_reduce_fn(cluster_size)
        ema_inplace(self.cluster_size.data, cluster_size, self.decay)

        # EMA update of the running feature sum assigned to each code.
        embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
        self.all_reduce_fn(embed_sum.contiguous())
        ema_inplace(self.embed_avg.data, embed_sum, self.decay)

        # Laplace-smooth the counts to avoid division by zero, then turn the running sums into new code vectors.
        cluster_size = laplace_smoothing(
            self.cluster_size, self.codebook_size, self.eps
        ) * self.cluster_size.sum(dim=-1, keepdim=True)
        embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
        self.embed.data.copy_(embed_normalized)

        # Re-initialize codes that have gone unused for too long.
        self.expire_codes_(x)

    if needs_codebook_dim:
        # Drop the singleton codebook-head dimension that was added at the top.
        quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))

    dist = unpack_one(dist, ps, "h * d")

    return quantize, embed_ind, dist
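
A minimal usage sketch, assuming this `forward` belongs to an `nn.Module`-style codebook (referred to below as `codebook`); the instance name, feature sizes and codebook size are illustrative assumptions, not values taken from this file:

import torch

# Hypothetical setup: `codebook` is an already constructed instance of the module that
# owns this forward(); the batch/sequence/feature sizes below are made up for illustration.
batch, seq_len, dim = 8, 16, 256

x = torch.randn(batch, seq_len, dim)                 # encoder features, no codebook-head dim yet
mask = torch.ones(batch, seq_len, dtype=torch.bool)  # all tokens valid

quantize, embed_ind, dist = codebook(x, mask=mask)

# quantize:  (batch, seq_len, dim)  - features snapped to their assigned codes
# embed_ind: (batch, seq_len)       - index of the chosen code for every token
# dist:      negated distances to every code (the leading codebook-head dim is kept here)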