in lerobot/common/policies/act/modeling_act.py [0:0]
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
[image_features]: (B, n_cameras, C, H, W) batch of images.
AND/OR
[env_state_feature]: (B, env_dim) batch of environment states.
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
(B, chunk_size, action_dim) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.config.use_vae and self.training:
assert "action" in batch, (
"actions must be provided when using the variational objective in training mode."
)
if "observation.images" in batch:
batch_size = batch["observation.images"][0].shape[0]
else:
batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
) # (bs, seq+1 or 2)
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=encoder_in_pos_embed.dtype,
device=encoder_in_pos_embed.device,
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=encoder_in_pos_embed,
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
)
# Move back to (B, S, C).
decoder_out = decoder_out.transpose(0, 1)
actions = self.action_head(decoder_out)
return actions, (mu, log_sigma_x2)