def get_user_indices()

in flsim/active_user_selectors/diverse_user_selector.py [0:0]


    def get_user_indices(self, **kwargs) -> List[int]:
        """Pick a uniformly random cohort and, while active, score candidate cohorts.

        The returned cohort is always drawn uniformly at random without
        replacement; the diversity computation never changes it.

        During the active epoch window
        ``[epochs_before_active, epochs_before_active + num_epochs_active)``
        this method also calls
        ``DiverseUserSelectorUtils.calculate_diversity_metrics`` once per
        candidate cohort in ``self.sample_cohorts``. It stores the results, in
        cohort order, in ``self.last_diversity_metrics``; previously they were
        computed and discarded. Outside that window
        ``self.last_diversity_metrics`` is reset to an empty list.

        Candidate cohorts are (re)sampled when none exist yet or when
        ``cfg.constant_cohorts`` is false. Otherwise the existing candidate
        cohorts are reused.

        Keyword Args:
            num_total_users: Size of the user population to sample from.
            users_per_round: Size of the returned cohort and of every
                candidate cohort.
            data_provider: Forwarded to the diversity-metric computation.
            global_model: Forwarded to the diversity-metric computation.
            epoch: Current epoch; decides whether metrics are computed.

        Returns:
            ``users_per_round`` distinct user indices in
            ``[0, num_total_users)``.
        """
        required_inputs = [
            "num_total_users",
            "users_per_round",
            "data_provider",
            "global_model",
            "epoch",
        ]
        (
            num_total_users,
            users_per_round,
            data_provider,
            global_model,
            epoch,
        ) = self.unpack_required_inputs(required_inputs, kwargs)

        def sample_cohort() -> List[int]:
            # Uniform weights + sampling without replacement -> distinct indices.
            return torch.multinomial(
                torch.ones(num_total_users, dtype=torch.float),
                users_per_round,
                replacement=False,
                generator=self.rng,
            ).tolist()

        # Draw the returned cohort before anything else so the rng stream used
        # for selection is identical whether or not reporting is active.
        selected_indices = sample_cohort()

        # pyre-fixme[16]: `DiversityReportingUserSelector` has no attribute `cfg`.
        cfg = self.cfg
        currently_active = (
            cfg.epochs_before_active
            <= epoch
            < cfg.epochs_before_active + cfg.num_epochs_active
        )

        if not currently_active:
            # Clear metrics so a reader never sees values from an earlier round.
            self.last_diversity_metrics = []
            return selected_indices

        if not self.sample_cohorts or not cfg.constant_cohorts:
            self.sample_cohorts = [
                sample_cohort() for _ in range(cfg.num_candidate_cohorts)
            ]

        # Keep the metrics instead of discarding them; one entry per candidate
        # cohort, in the same order as self.sample_cohorts.
        self.last_diversity_metrics = [
            DiverseUserSelectorUtils.calculate_diversity_metrics(
                data_provider=data_provider,
                global_model=global_model,
                user_indices=cohort,
                loss_reduction_type=cfg.loss_reduction_type,
                client_gradient_scaling=cfg.client_gradient_scaling,
                diversity_metric_type=cfg.diversity_metric_type,
            )
            for cohort in self.sample_cohorts
        ]

        return selected_indices