def get_action()

in activemri/baselines/simple_baselines.py [0:0]


    def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
        """Returns a one-step greedy action maximizing reconstruction score.

        For every batch element, up to ``self.num_samples`` not-yet-active k-space
        columns are taken from a random permutation (shuffled with ``self.rng``),
        and each candidate is evaluated with ``self.env.try_action``. The
        candidate with the best score under ``self.metric`` is returned.

        Args:
            obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
                Only ``obs["mask"]`` is read; zero/False entries of ``mask[i]`` are
                treated as inactive (available) columns. Each ``mask[i]`` is
                assumed to be a 1-D mask over k-space columns.

        Returns:
            list(int): A list of k-space column indices, one per batch element in
                the observation, equal to the action that maximizes reconstruction score
                (e.g, SSIM or negative MSE).

        Raises:
            ValueError: If ``self.metric`` is not one of ``"mse"``, ``"nmse"``,
                ``"ssim"`` or ``"psnr"``.
        """
        # Validate before running any (potentially expensive) env.try_action calls.
        if self.metric not in ("mse", "nmse", "ssim", "psnr"):
            raise ValueError(f"Unsupported metric: {self.metric}")
        # Error-like metrics are better when lower; negate them so argmax applies.
        sign = -1.0 if self.metric in ("mse", "nmse") else 1.0

        mask = obs["mask"]
        batch_size = mask.shape[0]
        all_action_lists = []
        for i in range(batch_size):
            # flatten() (not squeeze()) guarantees a 1-D result: with a single
            # available column squeeze() yields a 0-d tensor and tolist() would
            # return a bare int instead of a list.
            available_actions = mask[i].logical_not().nonzero().flatten().tolist()
            self.rng.shuffle(available_actions)
            if len(available_actions) < self.num_samples:
                # All batch elements must try the same number of actions. Pad with
                # a repeat of a real candidate so padding can never beat a truly
                # available column; use column 0 only if nothing is available.
                filler = available_actions[0] if available_actions else 0
                available_actions.extend(
                    [filler] * (self.num_samples - len(available_actions))
                )
            all_action_lists.append(available_actions)

        all_scores = np.zeros((batch_size, self.num_samples))
        for i in range(self.num_samples):
            batch_action_to_try = [action_list[i] for action_list in all_action_lists]
            _, new_score = self.env.try_action(batch_action_to_try)
            all_scores[:, i] = new_score[self.metric]
        all_scores *= sign

        best_indices = all_scores.argmax(axis=1)
        return [
            action_list[best]
            for action_list, best in zip(all_action_lists, best_indices)
        ]