# Excerpt from lerobot/common/policies/vqbet/vqbet_utils.py — update_affine method.
def update_affine(self, data, embed, mask=None):
    """Update the EMA statistics used for affine reparameterization of the codebook.

    Maintains exponential-moving-average mean/variance buffers for both the
    codebook vectors (``codebook_mean`` / ``codebook_variance``) and the input
    batch (``batch_mean`` / ``batch_variance``) via ``self.update_with_decay``.

    Args:
        data: Input vectors, shape ``(heads, ..., dim)`` — flattened here to
            ``(heads, n, dim)``.
        embed: Codebook vectors, same layout convention as ``data``.
        mask: Optional boolean mask selecting the valid vectors in ``data``.

    Returns:
        None. All results are written into the EMA buffers as a side effect.
    """
    assert self.affine_param

    # unbiased=False => population variance, matching the EMA-statistics use case
    var_fn = partial(torch.var, unbiased=False)

    # --- codebook statistics: EMA-updated only while training ---
    # The rearrange is hoisted inside the branch since `embed` is only
    # consumed here (the original computed it unconditionally).
    if self.training:
        embed = rearrange(embed, "h ... d -> h (...) d")
        self.update_with_decay(
            "codebook_mean",
            reduce(embed, "h n d -> h 1 d", "mean"),
            self.affine_param_codebook_decay,
        )
        self.update_with_decay(
            "codebook_variance",
            reduce(embed, "h n d -> h 1 d", var_fn),
            self.affine_param_codebook_decay,
        )

    # --- batch statistics: prepare data, which depends on whether it has masking ---
    data = rearrange(data, "h ... d -> h (...) d")

    if mask is not None:
        # NOTE(review): re-folding the flat masked selection to (c, n, d) assumes
        # the mask keeps an equal number of vectors per head — confirm upstream.
        c = data.shape[0]
        data = rearrange(data[mask], "(c n) d -> c n d", c=c)

    # Single-process path: plain EMA update of batch mean/variance, then done.
    if not self.sync_affine_param:
        self.update_with_decay(
            "batch_mean",
            reduce(data, "h n d -> h 1 d", "mean"),
            self.affine_param_batch_decay,
        )
        self.update_with_decay(
            "batch_variance",
            reduce(data, "h n d -> h 1 d", var_fn),
            self.affine_param_batch_decay,
        )
        return

    # Distributed path: compute globally-consistent mean/variance across ranks
    # by all-reducing sums and the total vector count (denominator).
    num_vectors, device, dtype = data.shape[-2], data.device, data.dtype

    # number of vectors, for the denominator (summed across all processes)
    num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
    distributed.all_reduce(num_vectors)

    # distributed mean = global sum / global count
    batch_sum = reduce(data, "h n d -> h 1 d", "sum")
    distributed.all_reduce(batch_sum)
    batch_mean = batch_sum / num_vectors

    self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)

    # distributed (population) variance around the global mean
    variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
    distributed.all_reduce(variance_number)
    batch_variance = variance_number / num_vectors

    self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)