def __call__()

in activemri/experimental/cvpr19_models/data/masking_utils.py [0:0]


    def __call__(self, shape, seed=None):
        """Sample an undersampling mask for data of the given ``shape``.

        The number of columns is taken from ``shape[-2]``. ``self.rng`` is
        reseeded with ``seed`` before any random draw, so a fixed seed
        reproduces the same mask.

        Two modes choose how many low- and high-frequency lines to request:

        * ``self.random_num_lines`` truthy: the low-frequency count is drawn
          from ``range(self.min_lowf_lines, self.max_lowf_lines, 2)``. The
          high-frequency count is a Beta(``highf_beta_alpha``,
          ``highf_beta_beta``) fraction of the remaining columns, rounded
          down to an even number.
        * otherwise: an index is drawn uniformly over ``self.accelerations``
          and used to pick a (center_fraction, acceleration) pair. That gives
          ``round(num_cols * center_fraction)`` low-frequency lines and
          ``num_cols // acceleration`` minus that many high-frequency lines.

        The actual line selection is delegated to
        ``self.create_lf_focused_mask``.

        Args:
            shape: Shape of the data to be masked; needs at least 3 dims.
            seed: Passed straight to ``self.rng.seed``. With ``None`` this
                presumably reseeds non-deterministically (assumes ``self.rng``
                is a NumPy ``RandomState`` -- TODO confirm).

        Returns:
            A float32 ``torch.Tensor`` with ``len(shape)`` dimensions, all of
            size 1 except the last, which has size ``num_cols``.

        Raises:
            ValueError: If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        self.rng.seed(seed)
        num_cols = shape[-2]

        if not self.random_num_lines:
            # Fixed mode: one (center_fraction, acceleration) pair by index.
            idx = self.rng.randint(0, len(self.accelerations))
            center_fraction = self.center_fractions[idx]
            acceleration = self.accelerations[idx]
            num_low_freqs = int(round(num_cols * center_fraction))
            # The line budget left over after the center lines goes to high
            # frequencies. This is negative if center_fraction > 1/acceleration.
            num_high_freqs = int(num_cols // acceleration - num_low_freqs)
        else:
            # Step of 2 keeps the count even only if min_lowf_lines is even.
            lowf_options = range(self.min_lowf_lines, self.max_lowf_lines, 2)
            num_low_freqs = self.rng.choice(lowf_options)
            fraction = self.rng.beta(self.highf_beta_alpha, self.highf_beta_beta)
            remaining = num_cols - num_low_freqs
            # Halve, floor, then double, so the high-frequency count is even.
            num_high_freqs = 2 * int(fraction * remaining // 2)

        # The returned array must hold exactly num_cols elements, or the
        # reshape below fails.
        line_mask = self.create_lf_focused_mask(
            num_cols, num_high_freqs, num_low_freqs
        )

        # Broadcastable shape: singleton on every axis except the last.
        # NOTE(review): columns are read from shape[-2] but placed on axis -1;
        # presumably this matches the caller's data layout -- confirm.
        target_shape = [1] * len(shape)
        target_shape[-1] = num_cols
        return torch.from_numpy(line_mask.reshape(target_shape).astype(np.float32))