in occant_baselines/supervised/map_update.py [0:0]
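# Assumed module-level imports for this excerpt (the file's actual import block is
# not shown above); subtract_pose, simple_mapping_loss_fn and pose_loss_fn are
# helpers assumed to be imported or defined elsewhere in this module.
import time

import torch
import torch.nn as nn
from einops import rearrange
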
def map_update_fn(ps_args):
    # Unpack args
    (
        mapper,
        mapper_rollouts,
        optimizer,
        num_update_batches,
        batch_size,
        freeze_projection_unit,
        bias_factor,
        occupancy_anticipator_type,
        pose_loss_coef,
        max_grad_norm,
        label_id,
    ) = ps_args
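    # NOTE: bias_factor is part of the argument tuple but is not used in the
    # update loop below.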
# Perform update
losses = {
"total_loss": 0,
"mapping_loss": 0,
"trans_loss": 0,
"rot_loss": 0,
}
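    # Running sums; each entry is averaged over num_update_batches before returning.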
if isinstance(mapper, nn.DataParallel):
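        # DataParallel wraps the underlying module, so the config lives on mapper.module.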
mapper_config = mapper.module.config
else:
mapper_config = mapper.config
img_mean = mapper_config.NORMALIZATION.img_mean
img_std = mapper_config.NORMALIZATION.img_std
start_time = time.time()
    # Profiling: wall-clock time spent in data sampling vs. the PyTorch update
map_update_profile = {"data_sampling": 0.0, "pytorch_update": 0.0}
for i in range(num_update_batches):
start_time_sample = time.time()
observations = mapper_rollouts.sample(batch_size)
map_update_profile["data_sampling"] += time.time() - start_time_sample
# Labels
# Pose labels
start_time_pyt = time.time()
device = observations["pose_gt_at_t_1"].device
pose_gt_at_t_1 = observations["pose_gt_at_t_1"]
pose_gt_at_t = observations["pose_gt_at_t"]
pose_at_t_1 = observations["pose_at_t_1"]
pose_at_t = observations["pose_at_t"]
dpose_gt = subtract_pose(pose_gt_at_t_1, pose_gt_at_t) # (bs, 3)
dpose_noisy = subtract_pose(pose_at_t_1, pose_at_t) # (bs, 3)
ddpose_gt = subtract_pose(dpose_noisy, dpose_gt)
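        # ddpose_gt is the residual between the noisy odometry delta and the
        # ground-truth delta; it is used below as the pose-estimation regression target.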
# Map labels
pt_gt = observations[f"{label_id}_at_t"] # (bs, V, V, 2)
pt_gt = rearrange(pt_gt, "b h w c -> b c h w") # (bs, 2, V, V)
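        # In this codebase the two map channels typically encode occupied and
        # explored space, respectively (assumption based on the ego-map convention).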
# Forward pass
mapper_inputs = observations
mapper_outputs = mapper(mapper_inputs, method_name="predict_deltas")
pt_hat = mapper_outputs["pt"]
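        # pt_hat is the predicted egocentric map, in the same (bs, 2, V, V) layout as pt_gt.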
# Compute losses
# -------- mapping loss ---------
mapping_loss = simple_mapping_loss_fn(pt_hat, pt_gt)
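        # When the projection unit is frozen, the mapping loss is detached so no
        # gradients flow from it; its value is still reported in the logged totals.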
if freeze_projection_unit:
mapping_loss = mapping_loss.detach()
if occupancy_anticipator_type == "rgb_model_v2":
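            # For this model variant, the intermediate depth-projection estimate is
            # additionally supervised against the ground-truth ego map.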
ego_map_gt = observations["ego_map_gt_at_t"] # (bs, V, V, 2)
ego_map_gt = rearrange(ego_map_gt, "b h w c -> b c h w")
ego_map_hat = mapper_outputs["all_pu_outputs"]["depth_proj_estimate"]
mapping_loss = mapping_loss + simple_mapping_loss_fn(
ego_map_hat, ego_map_gt
)
all_pose_outputs = mapper_outputs["all_pose_outputs"]
if all_pose_outputs is None:
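            # No pose outputs in this configuration: use scalar zero tensors so that
            # .backward() and .item() below still work unchanged.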
pose_estimation_loss = torch.zeros([0]).to(device).sum()
trans_loss = torch.zeros([0]).to(device).sum()
rot_loss = torch.zeros([0]).to(device).sum()
else:
pose_outputs = all_pose_outputs["pose_outputs"]
pose_estimation_loss, trans_loss, rot_loss = 0, 0, 0
            n_outputs = len(pose_outputs)
            # Pose deltas are predicted per input modality and then combined by an
            # ensemble MLP. The loss below is computed for every entry in pose_outputs
            # (the individual modalities and the ensemble) and averaged over n_outputs.
pose_label = ddpose_gt
for _, dpose_hat in pose_outputs.items():
curr_pose_losses = pose_loss_fn(dpose_hat, pose_label)
pose_estimation_loss = pose_estimation_loss + curr_pose_losses[0]
trans_loss = trans_loss + curr_pose_losses[1]
rot_loss = rot_loss + curr_pose_losses[2]
pose_estimation_loss = pose_estimation_loss / n_outputs
trans_loss = trans_loss / n_outputs
rot_loss = rot_loss / n_outputs
optimizer.zero_grad()
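        # pose_loss_coef sets the relative weight of the pose-estimation loss
        # against the mapping loss in the combined objective.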
total_loss = mapping_loss + pose_estimation_loss * pose_loss_coef
# Backward pass
total_loss.backward()
# Update
nn.utils.clip_grad_norm_(mapper.parameters(), max_grad_norm)
optimizer.step()
losses["total_loss"] += total_loss.item()
losses["mapping_loss"] += mapping_loss.item()
losses["trans_loss"] += trans_loss.item()
losses["rot_loss"] += rot_loss.item()
map_update_profile["pytorch_update"] += time.time() - start_time_pyt
        time_per_step = (time.time() - start_time) / (60 * (i + 1))  # minutes per update batch (not used below)
losses["pose_loss"] = losses["trans_loss"] + losses["rot_loss"]
for k in losses.keys():
losses[k] /= num_update_batches
return losses
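
# Minimal usage sketch (hypothetical names; assumes `mapper` is the mapping model,
# `mapper_rollouts` exposes .sample(batch_size), and `optimizer` is a torch optimizer
# over mapper.parameters()):
#
#     ps_args = (
#         mapper, mapper_rollouts, optimizer,
#         num_update_batches, batch_size, freeze_projection_unit,
#         bias_factor, occupancy_anticipator_type, pose_loss_coef,
#         max_grad_norm, label_id,
#     )
#     losses = map_update_fn(ps_args)
#     print(losses["total_loss"], losses["pose_loss"])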