def get_user_indices()

in flsim/active_user_selectors/diverse_user_selector.py [0:0]


    def get_user_indices(self, **kwargs) -> List[int]:
        """Select this round's training cohort and record diversity statistics.

        The cohort actually returned for training is ALWAYS a uniformly random
        sample (without replacement within the round) of ``users_per_round``
        users out of ``num_total_users``. The diversity search below is used
        only for reporting and does not affect the returned indices.

        Diversity statistics are collected only for epochs in the window
        ``[cfg.epochs_before_active, cfg.epochs_before_active +
        cfg.num_epochs_active)``. Within that window, two searches run:

        * one over all users (``list(range(num_total_users))``) on every
          round. Its statistic goes to ``self.cohort_stats_with_replacement``.
        * one over ``self.available_users`` only. Its statistic goes to
          ``self.cohort_stats_without_replacement``, and the chosen users are
          then removed from ``self.available_users``. The pool is refilled
          with all users once it becomes empty (round-robin).

        All random draws use ``self.rng``, in a fixed order: baseline, then
        the with-replacement search, then the without-replacement search.

        Keyword Args:
            num_total_users: total number of users to sample from.
            users_per_round: cohort size.
            data_provider: forwarded to
                ``DiverseUserSelectorUtils.select_diverse_cohort``.
            global_model: forwarded to
                ``DiverseUserSelectorUtils.select_diverse_cohort``.
            epoch: current epoch, used to decide whether statistics are
                collected.

        Returns:
            The uniformly random baseline cohort, as a list of user indices.
        """
        required_inputs = [
            "num_total_users",
            "users_per_round",
            "data_provider",
            "global_model",
            "epoch",
        ]
        (
            num_total_users,
            users_per_round,
            data_provider,
            global_model,
            epoch,
        ) = self.unpack_required_inputs(required_inputs, kwargs)

        # Refill the round-robin pool once every user has been consumed.
        if not self.available_users:
            self.available_users = list(range(num_total_users))

        # First uniformly randomly select the actual cohort used for training.
        # This draw happens before the epoch check so that RNG consumption for
        # the baseline is identical in every epoch.
        baseline_selected_indices = torch.multinomial(
            torch.ones(num_total_users, dtype=torch.float),
            users_per_round,
            replacement=False,
            generator=self.rng,
        ).tolist()

        if (
            # pyre-fixme[16]: `DiversityStatisticsReportingUserSelector` has no
            #  attribute `cfg`.
            epoch < self.cfg.epochs_before_active
            or epoch >= self.cfg.epochs_before_active + self.cfg.num_epochs_active
        ):
            return baseline_selected_indices

        # Diversity search over the full population. Users may recur across
        # rounds here, hence "with replacement".
        (_, stat_with_replacement) = DiverseUserSelectorUtils.select_diverse_cohort(
            data_provider=data_provider,
            global_model=global_model,
            users_per_round=users_per_round,
            available_users=list(range(num_total_users)),
            rng=self.rng,
            num_search_samples=self.cfg.num_candidate_cohorts,
            maximize_metric=self.cfg.maximize_metric,
            loss_reduction_type=self.cfg.loss_reduction_type,
            client_gradient_scaling=self.cfg.client_gradient_scaling,
            diversity_metric_type=self.cfg.diversity_metric_type,
        )

        # Diversity search restricted to users not yet picked in this
        # round-robin pass.
        (
            candidate_user_indices,
            stat_without_replacement,
        ) = DiverseUserSelectorUtils.select_diverse_cohort(
            data_provider=data_provider,
            global_model=global_model,
            users_per_round=users_per_round,
            available_users=self.available_users,
            rng=self.rng,
            num_search_samples=self.cfg.num_candidate_cohorts,
            maximize_metric=self.cfg.maximize_metric,
            loss_reduction_type=self.cfg.loss_reduction_type,
            client_gradient_scaling=self.cfg.client_gradient_scaling,
            diversity_metric_type=self.cfg.diversity_metric_type,
        )

        self.cohort_stats_with_replacement.append(stat_with_replacement)
        self.cohort_stats_without_replacement.append(stat_without_replacement)

        # Update the list of available users for the round-robin selector.
        # Build the exclusion set once, so that pruning is linear in the pool
        # size instead of doing a list scan per user.
        # NOTE(review): assumes the returned indices are hashable ints
        # comparable to the ints in `available_users` — confirm against
        # select_diverse_cohort's return type.
        chosen = set(candidate_user_indices)
        self.available_users = [
            idx for idx in self.available_users if idx not in chosen
        ]

        return baseline_selected_indices