in lerobot/common/policies/tdmpc/modeling_tdmpc.py [0:0]
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
    """Run the batch through the model and compute the training loss.

    The batch is normalized, its multi-dimensional tensors are temporarily transposed from
    (b, t, ...) to (t, b, ...), a latent rollout is performed from the first observation using the
    dataset actions, and five loss terms (consistency, reward, Q-value, V-value, policy) are combined
    into a single weighted sum.

    Args:
        batch: Mapping of tensors. This method reads `ACTION`, `REWARD`, every key starting with
            "observation.", and the padding masks "observation.state_is_pad", "action_is_pad" and
            "next.reward_is_pad".

    Returns:
        A tuple `(loss, info)`. `loss` is the weighted sum of the individual loss terms. `info` holds
        logging values: "Q", "V" and the individual/summed losses are native floats, while "advantage"
        is kept as a tensor (the advantage at the first rollout step).
    """
    device = get_device_from_parameters(self)

    batch = self.normalize_inputs(batch)
    if self.config.image_features:
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
    batch = self.normalize_targets(batch)

    info = {}

    # (b, t) -> (t, b)
    # Only tensors with more than one dimension are transposed; 1-D entries are left as they are.
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
            batch[key] = batch[key].transpose(1, 0)

    action = batch[ACTION]  # (t, b, action_dim)
    reward = batch[REWARD]  # (t, b)
    observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

    # Apply random image augmentations.
    if self.config.image_features and self.config.max_random_shift_ratio > 0:
        observations[OBS_IMAGE] = flatten_forward_unflatten(
            partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
            observations[OBS_IMAGE],
        )

    # Get the current observation for predicting trajectories, and all future observations for use in
    # the latent consistency loss and TD loss.
    current_observation, next_observations = {}, {}
    for k in observations:
        current_observation[k] = observations[k][0]
        next_observations[k] = observations[k][1:]
    # Both the rollout length and the batch size are read from the (t, b, ...) observation tensor, so
    # no auxiliary key (such as "index") is required in the batch.
    horizon, batch_size = next_observations[
        OBS_IMAGE if self.config.image_features else OBS_ENV_STATE
    ].shape[:2]

    # Run latent rollout using the latent dynamics model and policy model.
    # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
    # gives us a next `z`.
    z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
    z_preds[0] = self.model.encode(current_observation)
    reward_preds = torch.empty_like(reward, device=device)
    for t in range(horizon):
        z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])

    # Compute Q and V value predictions based on the latent rollout.
    q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
    v_preds = self.model.V(z_preds[:-1])
    info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})

    # Compute various targets with stopgrad.
    with torch.no_grad():
        # Latent state consistency targets.
        z_targets = self.model_target.encode(next_observations)
        # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
        # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
        # uses a learned state value function: V(z). This means the TD targets only depend on in-sample
        # actions (not actions estimated by π).
        # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
        # and the FOWM paper.
        q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
        # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
        # are using them to compute loss for V.
        v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)

    # Compute losses.
    # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
    # future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
    temporal_loss_coeffs = torch.pow(
        self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
    ).unsqueeze(-1)
    # Compute consistency loss as MSE loss between latents predicted from the rollout and latents
    # predicted from the (target model's) observation encoder.
    consistency_loss = (
        (
            temporal_loss_coeffs
            * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
            # `z_preds` depends on the current observation and the actions.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
            # `z_targets` depends on the next observation.
            * ~batch["observation.state_is_pad"][1:]
        )
        .sum(0)
        .mean()
    )
    # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
    # rewards.
    reward_loss = (
        (
            temporal_loss_coeffs
            * F.mse_loss(reward_preds, reward, reduction="none")
            * ~batch["next.reward_is_pad"]
            # `reward_preds` depends on the current observation and the actions.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
        )
        .sum(0)
        .mean()
    )
    # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
    q_value_loss = (
        (
            temporal_loss_coeffs
            * F.mse_loss(
                q_preds_ensemble,
                einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                reduction="none",
            ).sum(0)  # sum over ensemble
            # `q_preds_ensemble` depends on the first observation and the actions.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
            # q_targets depends on the reward and the next observations.
            * ~batch["next.reward_is_pad"]
            * ~batch["observation.state_is_pad"][1:]
        )
        .sum(0)
        .mean()
    )
    # Compute state value loss as in eqn 3 of FOWM.
    diff = v_targets - v_preds
    # Expectile loss penalizes:
    #   - `v_preds <  v_targets` with weighting `expectile_weight`
    #   - `v_preds >= v_targets` with weighting `1 - expectile_weight`
    raw_v_value_loss = torch.where(
        diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
    ) * (diff**2)
    v_value_loss = (
        (
            temporal_loss_coeffs
            * raw_v_value_loss
            # `v_targets` depends on the first observation and the actions, as does `v_preds`.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
        )
        .sum(0)
        .mean()
    )

    # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
    # We won't need these gradients again so detach.
    z_preds = z_preds.detach()
    # Use stopgrad for the advantage calculation.
    with torch.no_grad():
        advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
            z_preds[:-1]
        )
        # Stored as a tensor (first rollout step only), unlike the other `info` entries.
        info["advantage"] = advantage[0]
        # (t, b)
        # The exponentiated advantage is clamped at 100 so a single sample cannot dominate the loss.
        exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
    action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
    # Calculate the MSE between the actions and the action predictions.
    # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
    # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
    # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
    # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
    # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
    # loss).
    mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
    # NOTE: The original implementation does not take the sum over the temporal dimension like with the
    # other losses.
    # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works
    # as well as expected.
    pi_loss = (
        exp_advantage
        * mse
        * temporal_loss_coeffs
        # `action_preds` depends on the first observation and the actions.
        * ~batch["observation.state_is_pad"][0]
        * ~batch["action_is_pad"]
    ).mean()

    loss = (
        self.config.consistency_coeff * consistency_loss
        + self.config.reward_coeff * reward_loss
        + self.config.value_coeff * q_value_loss
        + self.config.value_coeff * v_value_loss
        + self.config.pi_coeff * pi_loss
    )

    info.update(
        {
            "consistency_loss": consistency_loss.item(),
            "reward_loss": reward_loss.item(),
            "Q_value_loss": q_value_loss.item(),
            "V_value_loss": v_value_loss.item(),
            "pi_loss": pi_loss.item(),
            "sum_loss": loss.item() * self.config.horizon,
        }
    )

    # Undo (b, t) -> (t, b).
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
            batch[key] = batch[key].transpose(1, 0)

    return loss, info