in lerobot/common/policies/vqbet/vqbet_utils.py [0:0]
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
    """
    Quantize `x` with a cascade of residual vector-quantization layers.

    `x` is first passed through `self.project_in`. Every layer in `self.layers` then quantizes the
    current residual; the detached quantized value is subtracted from the residual before it is fed
    to the next layer, and the per-layer quantized values are summed. The sum is finally passed
    through `self.project_out`.

    Args:
        x: Input tensor. NOTE(review): the null-index shape below implies the projected input is
            (batch, sequence, dim), or (batch, ..., height, width) when `self.accept_image_fmap`
            is set - confirm against the callers.
        indices: Optional tensor of code indices whose last axis selects the quantizer layer.
            When given, each layer receives `indices[..., layer_index]` and the function returns a
            loss instead of indices. Must not contain -1 (the marker used for dropped-out layers)
            and is not supported together with `self.accept_image_fmap`.
        return_all_codes: If True, additionally return
            `self.get_codebook_vector_from_indices(all_indices)`. Ignored when `indices` is given,
            because that path returns early.
        sample_codebook_temp: Forwarded unchanged to every layer.

    Returns:
        - If `indices` is given: `(quantized_out, loss)`, where `loss` is the sum over layers of
          the second value returned by each layer.
        - Otherwise: `(quantized_out, all_indices, all_losses)`, with the per-layer indices and
          losses stacked along a new last axis, plus `all_codes` as a fourth element when
          `return_all_codes` is True.

    Raises:
        AssertionError: If `indices` is given while `self.accept_image_fmap` is set, or if
            `indices` contains -1.
    """
    num_quant = self.num_quantizers
    quant_dropout_multiple_of = self.quantize_dropout_multiple_of
    return_loss = indices is not None
    device = x.device

    x = self.project_in(x)

    # The loss path is not supported together with image feature maps.
    assert not (self.accept_image_fmap and (indices is not None))

    quantized_out = 0.0
    residual = x

    all_losses = []
    all_indices = []
    ce_losses = []

    if return_loss:
        assert not torch.any(indices == -1), (
            "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
        )

    should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

    # sample a layer index at which to dropout further residual quantization
    # also prepare null indices and loss
    if should_quantize_dropout:
        rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

        if quant_dropout_multiple_of != 1:
            # Round the number of kept layers (index + 1) up to the next multiple.
            rand_quantize_dropout_index = (
                ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of)
                * quant_dropout_multiple_of
                - 1
            )

        null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
        # -1 marks "no code selected" for dropped-out layers; use an integer fill value since the
        # tensor is of integer dtype.
        null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long)
        null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

    # go through the layers
    for quantizer_index, layer in enumerate(self.layers):
        if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
            # Dropped-out layer: contribute placeholders so the stacked outputs keep one entry
            # per layer.
            all_indices.append(null_indices)
            all_losses.append(null_loss)
            continue

        layer_indices = None
        if return_loss:
            layer_indices = indices[..., quantizer_index]

        quantized, *rest = layer(
            residual,
            indices=layer_indices,
            sample_codebook_temp=sample_codebook_temp,
            freeze_codebook=self.freeze_codebook,
        )

        # The residual update is detached, so gradients do not flow through the subtraction.
        residual = residual - quantized.detach()
        quantized_out = quantized_out + quantized

        if return_loss:
            # With indices supplied, the layer returns (quantized, loss).
            ce_loss = rest[0]
            ce_losses.append(ce_loss)
            continue

        # Without indices, the layer returns (quantized, indices, loss).
        embed_indices, loss = rest
        all_indices.append(embed_indices)
        all_losses.append(loss)

    # project out, if needed
    quantized_out = self.project_out(quantized_out)

    # whether to early return the cross entropy loss
    if return_loss:
        return quantized_out, sum(ce_losses)

    # stack all losses and indices
    all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

    ret = (quantized_out, all_indices, all_losses)

    if return_all_codes:
        # whether to return all codes from all codebooks across layers
        all_codes = self.get_codebook_vector_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
        ret = (*ret, all_codes)

    return ret