def map_update_fn()

in occant_baselines/supervised/map_update.py [0:0]


def map_update_fn(ps_args):
    """Run supervised update steps on the mapper and return mean losses.

    For each of ``num_update_batches`` iterations a batch is sampled from
    ``mapper_rollouts``, the mapper is run with
    ``method_name="predict_deltas"``, a mapping loss and (when pose outputs
    are available) a pose loss are computed, and one clipped optimizer step
    is taken.

    Args:
        ps_args: Indexable sequence read by position:
            0.  mapper - model called as ``mapper(inputs, method_name=...)``;
                its ``parameters()`` are gradient-clipped.
            1.  mapper_rollouts - object exposing ``sample(batch_size)``.
            2.  optimizer - exposes ``zero_grad()`` and ``step()``.
            3.  num_update_batches - number of update iterations.
            4.  batch_size - passed to ``mapper_rollouts.sample``.
            5.  freeze_projection_unit - if truthy, the primary mapping loss
                is detached before being added to the total loss.
            6.  bias_factor - accepted for positional compatibility; unused.
            7.  occupancy_anticipator_type - ``"rgb_model_v2"`` adds an extra
                loss on the ``depth_proj_estimate`` output.
            8.  pose_loss_coef - weight of the pose loss in the total loss.
            9.  max_grad_norm - threshold for gradient-norm clipping.
            10. label_id - prefix of the ``"{label_id}_at_t"`` map label key.

    Returns:
        dict with keys ``total_loss``, ``mapping_loss``, ``trans_loss``,
        ``rot_loss`` and ``pose_loss`` (``trans_loss + rot_loss``), each
        averaged over ``num_update_batches``. All values are ``0.0`` when
        ``num_update_batches`` is not positive.
    """
    # Unpack args (index 6, bias_factor, is intentionally not read).
    mapper = ps_args[0]
    mapper_rollouts = ps_args[1]
    optimizer = ps_args[2]
    num_update_batches = ps_args[3]
    batch_size = ps_args[4]
    freeze_projection_unit = ps_args[5]
    occupancy_anticipator_type = ps_args[7]
    pose_loss_coef = ps_args[8]
    max_grad_norm = ps_args[9]
    label_id = ps_args[10]

    # Running sums; converted to means after the loop.
    losses = {
        "total_loss": 0.0,
        "mapping_loss": 0.0,
        "trans_loss": 0.0,
        "rot_loss": 0.0,
    }

    for _ in range(num_update_batches):
        observations = mapper_rollouts.sample(batch_size)
        device = observations["pose_gt_at_t_1"].device

        # Pose labels: difference between the noisy and ground-truth pose
        # deltas of consecutive steps.
        dpose_gt = subtract_pose(
            observations["pose_gt_at_t_1"], observations["pose_gt_at_t"]
        )  # (bs, 3)
        dpose_noisy = subtract_pose(
            observations["pose_at_t_1"], observations["pose_at_t"]
        )  # (bs, 3)
        ddpose_gt = subtract_pose(dpose_noisy, dpose_gt)

        # Map labels
        pt_gt = observations[f"{label_id}_at_t"]  # (bs, V, V, 2)
        pt_gt = rearrange(pt_gt, "b h w c -> b c h w")  # (bs, 2, V, V)

        # Forward pass
        mapper_outputs = mapper(observations, method_name="predict_deltas")
        pt_hat = mapper_outputs["pt"]

        # -------- mapping loss ---------
        mapping_loss = simple_mapping_loss_fn(pt_hat, pt_gt)
        if freeze_projection_unit:
            # Detached so this term is still reported but contributes no
            # gradient.
            # NOTE(review): if no other term requires grad (no pose outputs
            # and not "rgb_model_v2"), backward() below would fail - confirm
            # callers never hit that combination.
            mapping_loss = mapping_loss.detach()

        if occupancy_anticipator_type == "rgb_model_v2":
            ego_map_gt = observations["ego_map_gt_at_t"]  # (bs, V, V, 2)
            ego_map_gt = rearrange(ego_map_gt, "b h w c -> b c h w")
            ego_map_hat = mapper_outputs["all_pu_outputs"]["depth_proj_estimate"]
            mapping_loss = mapping_loss + simple_mapping_loss_fn(
                ego_map_hat, ego_map_gt
            )

        # -------- pose loss ---------
        all_pose_outputs = mapper_outputs["all_pose_outputs"]
        pose_outputs = (
            None if all_pose_outputs is None else all_pose_outputs["pose_outputs"]
        )
        if not pose_outputs:
            # No pose predictions (missing or empty): use scalar zeros on the
            # batch device so the bookkeeping below stays uniform and the
            # average over outputs never divides by zero.
            zero = torch.zeros((), device=device)
            pose_estimation_loss = trans_loss = rot_loss = zero
        else:
            # The pose prediction outputs are performed for individual
            # modalities, and then weighted-averaged according to an ensemble
            # MLP. Here, the loss is computed for each modality as well as the
            # ensemble. Finally, it is averaged across the modalities.
            pose_estimation_loss, trans_loss, rot_loss = 0, 0, 0
            for dpose_hat in pose_outputs.values():
                curr_pose_losses = pose_loss_fn(dpose_hat, ddpose_gt)
                pose_estimation_loss = pose_estimation_loss + curr_pose_losses[0]
                trans_loss = trans_loss + curr_pose_losses[1]
                rot_loss = rot_loss + curr_pose_losses[2]
            n_outputs = len(pose_outputs)
            pose_estimation_loss = pose_estimation_loss / n_outputs
            trans_loss = trans_loss / n_outputs
            rot_loss = rot_loss / n_outputs

        optimizer.zero_grad()
        total_loss = mapping_loss + pose_estimation_loss * pose_loss_coef

        # Backward pass
        total_loss.backward()

        # Update
        nn.utils.clip_grad_norm_(mapper.parameters(), max_grad_norm)
        optimizer.step()

        losses["total_loss"] += total_loss.item()
        losses["mapping_loss"] += mapping_loss.item()
        losses["trans_loss"] += trans_loss.item()
        losses["rot_loss"] += rot_loss.item()

    losses["pose_loss"] = losses["trans_loss"] + losses["rot_loss"]

    # Average over the update steps; with no steps there is nothing to
    # average, so the zero sums are returned as-is instead of dividing by 0.
    if num_update_batches > 0:
        for k in losses:
            losses[k] /= num_update_batches

    return losses