occant_baselines/rl/ans.py [643:690]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        else:
            # Condition 2 (a): The previous local goal was reached, i.e. the
            # current distance to it is below the goal success radius.
            prev_goal_reached = (
                self.states["curr_dist2localgoal"] < self.goal_success_radius
            )
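            # Condition 2 (b): The previous local goal is occupied, i.e. its
            # global-map cell exceeds both the obstacle threshold (channel 0)
            # and the explored threshold (channel 1).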
            goals = self.states["curr_local_goals"].long().to(device)
            prev_gcells = global_map[
                torch.arange(0, goals.shape[0]).long(), :, goals[:, 1], goals[:, 0]
            ]
            prev_goal_occupied = (prev_gcells[:, 0] > self.config.thresh_obstacle) & (
                prev_gcells[:, 1] > self.config.thresh_explored
            )
            SAMPLE_LOCAL_GOAL_FLAGS = asnumpy(
                (prev_goal_reached | prev_goal_occupied).float()
            ).tolist()
        # Execute planner and compute local goals
        self._compute_plans_and_local_goals(
            global_map, self.states["curr_map_position"], SAMPLE_LOCAL_GOAL_FLAGS
        )
        # Update state variables to account for new local goals
        self.states["curr_dist2localgoal"] = self._compute_dist2localgoal(
            global_map,
            self.states["curr_map_position"],
            self.states["curr_local_goals"],
        )
        # Sample action with local policy
        local_masks = 1 - torch.Tensor(SAMPLE_LOCAL_GOAL_FLAGS).to(device).unsqueeze(1)
        recurrent_hidden_states = prev_state_estimates["recurrent_hidden_states"]
        relative_goals = self._compute_relative_local_goals(global_pose, M, s)
        local_policy_inputs = {
            "rgb_at_t": observations["rgb"],
            "goal_at_t": relative_goals,
            "t": ep_time,
        }
        outputs = self.local_policy.act(
            local_policy_inputs,
            recurrent_hidden_states,
            None,
            local_masks,
            deterministic=deterministic,
        )
        (
            local_value,
            local_action,
            local_action_log_probs,
            recurrent_hidden_states,
        ) = outputs
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



occant_baselines/rl/ans.py [1125:1172]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        else:
            # Condition 3: (a) The previous local goal was reached, i.e. the
            # current distance to it is below the goal success radius.
            prev_goal_reached = (
                self.states["curr_dist2localgoal"] < self.goal_success_radius
            )
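            # Condition 3: (b) The previous local goal is occupied, i.e. its
            # global-map cell exceeds both the obstacle threshold (channel 0)
            # and the explored threshold (channel 1).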
            goals = self.states["curr_local_goals"].long().to(device)
            prev_gcells = global_map[
                torch.arange(0, goals.shape[0]).long(), :, goals[:, 1], goals[:, 0]
            ]
            prev_goal_occupied = (prev_gcells[:, 0] > self.config.thresh_obstacle) & (
                prev_gcells[:, 1] > self.config.thresh_explored
            )
            SAMPLE_LOCAL_GOAL_FLAGS = asnumpy(
                (prev_goal_reached | prev_goal_occupied).float()
            ).tolist()
        # Execute planner and compute local goals
        self._compute_plans_and_local_goals(
            global_map, self.states["curr_map_position"], SAMPLE_LOCAL_GOAL_FLAGS
        )
        # Update state variables to account for new local goals
        self.states["curr_dist2localgoal"] = self._compute_dist2localgoal(
            global_map,
            self.states["curr_map_position"],
            self.states["curr_local_goals"],
        )
        # Sample action with local policy
        local_masks = 1 - torch.Tensor(SAMPLE_LOCAL_GOAL_FLAGS).to(device).unsqueeze(1)
        recurrent_hidden_states = prev_state_estimates["recurrent_hidden_states"]
        relative_goals = self._compute_relative_local_goals(global_pose, M, s)
        local_policy_inputs = {
            "rgb_at_t": observations["rgb"],
            "goal_at_t": relative_goals,
            "t": ep_time,
        }
        outputs = self.local_policy.act(
            local_policy_inputs,
            recurrent_hidden_states,
            None,
            local_masks,
            deterministic=deterministic,
        )
        (
            local_value,
            local_action,
            local_action_log_probs,
            recurrent_hidden_states,
        ) = outputs
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



