def get_user_indices()

in flsim/active_user_selectors/diverse_user_selector.py [0:0]


    def get_user_indices(self, **kwargs) -> List[int]:
        """Select the user indices that participate in the current round.

        Outside the active window, i.e. when
        ``epoch < cfg.epochs_before_active`` or
        ``epoch >= cfg.epochs_before_active + cfg.num_epochs_active``,
        ``users_per_round`` distinct users are drawn uniformly at random from
        all ``num_total_users`` users. Inside the window the cohort is chosen
        by ``DiverseUserSelectorUtils.select_diverse_cohort`` from
        ``self.available_users``. When ``cfg.with_replacement`` is false, the
        chosen users are then removed from that pool so they are not offered
        again until the pool is refilled.

        Keyword Args:
            num_total_users: Total number of users. The pool is (re)filled
                with ``range(num_total_users)`` whenever it is empty.
            users_per_round: Number of users to select this round.
            data_provider: Forwarded to ``select_diverse_cohort``.
            global_model: Forwarded to ``select_diverse_cohort``.
            epoch: Current epoch, compared against the active window.

        Returns:
            The list of selected user indices.
        """
        required_inputs = [
            "num_total_users",
            "users_per_round",
            "data_provider",
            "global_model",
            "epoch",
        ]
        (
            num_total_users,
            users_per_round,
            data_provider,
            global_model,
            epoch,
        ) = self.unpack_required_inputs(required_inputs, kwargs)

        # Refill the candidate pool once every user has been consumed.
        # NOTE(review): the pool is refilled only when completely empty, so an
        # active round may see fewer than ``users_per_round`` available users;
        # confirm that select_diverse_cohort handles that case.
        if not self.available_users:
            self.available_users = list(range(num_total_users))

        if (
            # pyre-fixme[16]: `DiversityMaximizingUserSelector` has no attribute `cfg`.
            epoch < self.cfg.epochs_before_active
            or epoch >= self.cfg.epochs_before_active + self.cfg.num_epochs_active
        ):
            # Outside the active window: uniform sampling without replacement
            # over all users. This path neither reads nor shrinks
            # ``self.available_users``.
            baseline_selected_indices = torch.multinomial(
                torch.ones(num_total_users, dtype=torch.float),
                users_per_round,
                replacement=False,
                generator=self.rng,
            ).tolist()
            return baseline_selected_indices

        # The second element of the returned pair is not needed here.
        candidate_user_indices, _ = DiverseUserSelectorUtils.select_diverse_cohort(
            data_provider=data_provider,
            global_model=global_model,
            users_per_round=users_per_round,
            available_users=self.available_users,
            rng=self.rng,
            num_search_samples=self.cfg.num_candidate_cohorts,
            maximize_metric=self.cfg.maximize_metric,
            loss_reduction_type=self.cfg.loss_reduction_type,
            client_gradient_scaling=self.cfg.client_gradient_scaling,
            diversity_metric_type=self.cfg.diversity_metric_type,
        )

        if not self.cfg.with_replacement:
            # Drop this round's cohort from the pool. A set gives O(1)
            # membership checks instead of rescanning the cohort list for
            # every pool entry.
            selected = set(candidate_user_indices)
            self.available_users = [
                idx for idx in self.available_users if idx not in selected
            ]

        return candidate_user_indices