def task_graph()

in MTRF/r3l/r3l/r3l_envs/inhand_env/basket.py [0:0]


    def task_graph(self):
        """Choose the next task phase from the current object state.

        Rebuilds ``self.task_adj_matrix`` as a ``num_phases x num_phases``
        matrix whose row ``i`` holds the transition probabilities out of
        phase ``i``. The matrix is derived from the REPOSITION env's
        observations. The next phase is then sampled from the row for
        ``self.phase``.

        Side effects:
            * Calls ``self._release_object()`` first if the current phase is
              REPOSITION_MIDAIR.
            * Overwrites ``self.task_adj_matrix``.
            * Prints the matrix when ``self._verbose`` is set.

        Returns:
            The index of the next phase, drawn with ``self.np_random.choice``.
            When ``self._random_task_graph`` is set, the adjacency matrix is
            ignored and the phase is drawn uniformly from all phases.

        Raises:
            NotImplementedError: if the object state matches none of the
                transition rules. This happens only when a comparison is
                False both ways, e.g. a NaN height or distance.
        """
        Phase = self.Phase
        # The four phases whose outgoing transitions are defined below.
        all_phases = (
            Phase.PERTURB,
            Phase.REPOSITION,
            Phase.PICKUP,
            Phase.REPOSITION_MIDAIR,
        )

        # Drop the object if we were just in the midair reposition phase.
        if self.phase == Phase.REPOSITION_MIDAIR.value:
            self._release_object()

        # Construct a new adjacency matrix: row = current phase, col = next.
        self.task_adj_matrix = np.zeros((self.num_phases, self.num_phases))
        adj = self.task_adj_matrix

        # Only the reposition env's observations decide the transitions.
        repos_obs = self._envs[Phase.REPOSITION.value].get_obs_dict()
        repos_xy_dist = repos_obs["object_to_target_xy_distance"]
        object_xyz = repos_obs["object_xyz"]
        on_table = object_xyz[2] <= self.Z_DISTANCE_THRESHOLD

        if on_table and repos_xy_dist >= self.DISTANCE_THRESHOLD:
            # Object is on the table and at least DISTANCE_THRESHOLD (xy)
            # away from the target. Every other phase goes to REPOSITION.
            # REPOSITION itself goes to PERTURB, or self-loops when
            # perturbation is turned off.
            for src in (Phase.PERTURB, Phase.PICKUP, Phase.REPOSITION_MIDAIR):
                adj[src.value, Phase.REPOSITION.value] = 1.0
            repos_next = Phase.REPOSITION if self._perturb_off else Phase.PERTURB
            adj[Phase.REPOSITION.value, repos_next.value] = 1.0
        elif on_table and repos_xy_dist < self.DISTANCE_THRESHOLD:
            # Object is on the table within DISTANCE_THRESHOLD (xy) of the
            # target; no orientation check is made here. Every phase,
            # PICKUP included, goes to PICKUP.
            for src in all_phases:
                adj[src.value, Phase.PICKUP.value] = 1.0
        elif object_xyz[2] > self.Z_DISTANCE_THRESHOLD:
            # Object is above the table height threshold: every phase,
            # REPOSITION_MIDAIR included, goes to REPOSITION_MIDAIR.
            for src in all_phases:
                adj[src.value, Phase.REPOSITION_MIDAIR.value] = 1.0
        else:
            raise NotImplementedError(
                "Should not have reached this condition with: "
                f"repos_xy_dist={repos_xy_dist}, object_xyz={object_xyz}"
            )

        if self._verbose:
            print("TASK ADJACENCY MATRIX:\n", self.task_adj_matrix)

        # The `self.phase` row of the adjacency matrix gives the next-phase
        # transition probabilities.
        task_adj_list = self.task_adj_matrix[self.phase]
        row_sum = np.sum(task_adj_list)
        # Tolerant comparison: exact float equality would break as soon as a
        # row is split into fractional probabilities.
        assert np.isclose(row_sum, 1.0), (
            f"Transition probabilities out of phase {self.phase} sum to "
            f"{row_sum}, expected 1"
        )
        if self._random_task_graph:
            # Ignore the adjacency matrix and pick any phase uniformly.
            next_phase = self.np_random.choice(self.num_phases)
        else:
            next_phase = self.np_random.choice(self.num_phases, p=task_adj_list)

        return next_phase