in activemri/baselines/simple_baselines.py [0:0]
def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
    """Returns a one-step greedy action maximizing reconstruction score.

    For every batch element, the inactive columns of ``obs["mask"]`` are
    shuffled with ``self.rng`` and the first ``self.num_samples`` of them are
    scored through ``self.env.try_action``. The best-scoring candidate is
    returned for each element.

    Args:
        obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
            Only ``obs["mask"]`` is read.

    Returns:
        list(int): A list of k-space column indices, one per batch element in
            the observation, equal to the action that maximizes reconstruction
            score (e.g, SSIM or negative MSE). If a batch element has no
            inactive column left, the placeholder action ``0`` is returned
            for it.

    Raises:
        ValueError: If ``self.metric`` is not one of ``"mse"``, ``"nmse"``,
            ``"ssim"`` or ``"psnr"``.
    """
    # Validate before any (potentially expensive) call to the environment, and
    # with a real exception so the check survives ``python -O``.
    lower_is_better = self.metric in ("mse", "nmse")
    if not lower_is_better and self.metric not in ("ssim", "psnr"):
        raise ValueError(f"Unsupported metric: {self.metric!r}")

    mask = obs["mask"]
    batch_size = mask.shape[0]
    all_action_lists = []
    num_real_candidates = []
    for i in range(batch_size):
        # squeeze(-1) instead of squeeze(): with exactly one inactive column a
        # bare squeeze() collapses the result to a 0-d tensor and tolist()
        # returns an int rather than a list, which breaks shuffle()/len().
        # NOTE(review): assumes mask[i] is 1-D, so nonzero() is (N, 1) - confirm.
        available_actions = mask[i].logical_not().nonzero().squeeze(-1).tolist()
        self.rng.shuffle(available_actions)
        num_real_candidates.append(min(len(available_actions), self.num_samples))
        shortfall = self.num_samples - len(available_actions)
        if shortfall > 0:
            # Add dummy actions to try if num of samples is higher than the
            # number of inactive columns in this mask
            available_actions.extend([0] * shortfall)
        all_action_lists.append(available_actions)

    all_scores = np.zeros((batch_size, self.num_samples))
    for i in range(self.num_samples):
        batch_action_to_try = [action_list[i] for action_list in all_action_lists]
        _, new_score = self.env.try_action(batch_action_to_try)
        all_scores[:, i] = new_score[self.metric]
    if lower_is_better:
        # Negate error metrics so that argmax always picks the best candidate.
        all_scores *= -1

    # Padding slots are placeholders, not real candidates: make sure they can
    # never outscore a genuinely available column.
    for i, num_real in enumerate(num_real_candidates):
        all_scores[i, num_real:] = -np.inf

    best_indices = all_scores.argmax(axis=1)
    return [
        action_list[best]
        for action_list, best in zip(all_action_lists, best_indices)
    ]