in flsim/active_user_selectors/diverse_user_selector.py [0:0]
def get_user_indices(self, **kwargs) -> List[int]:
    """Select the indices of the users participating in the current round.

    Required keyword arguments (extracted via ``self.unpack_required_inputs``):
        num_total_users: Size of the user population; indices are drawn
            from ``range(num_total_users)``.
        users_per_round: Number of users requested for the round.
        data_provider: Forwarded to the diverse-cohort search.
        global_model: Forwarded to the diverse-cohort search.
        epoch: Current training epoch; decides whether diversity-based
            selection is active.

    Diversity-based selection is active for epochs in the half-open window
    ``[cfg.epochs_before_active, cfg.epochs_before_active + cfg.num_epochs_active)``.

    * Outside the window, ``users_per_round`` distinct indices are sampled
      uniformly from the whole population with ``self.rng``. The pool in
      ``self.available_users`` is neither consulted nor modified.
    * Inside the window, the cohort is chosen by
      ``DiverseUserSelectorUtils.select_diverse_cohort`` from
      ``self.available_users``. When ``cfg.with_replacement`` is false, the
      chosen users are removed from that pool.

    The pool is refilled with every user whenever it is empty at the start
    of a call.

    Returns:
        The selected user indices.
    """
    required_inputs = [
        "num_total_users",
        "users_per_round",
        "data_provider",
        "global_model",
        "epoch",
    ]
    (
        num_total_users,
        users_per_round,
        data_provider,
        global_model,
        epoch,
    ) = self.unpack_required_inputs(required_inputs, kwargs)

    # Refill the candidate pool once it has been fully consumed. With
    # replacement the pool never shrinks, so this only fires on first use.
    if not self.available_users:
        self.available_users = list(range(num_total_users))

    # pyre-fixme[16]: `DiversityMaximizingUserSelector` has no attribute `cfg`.
    cfg = self.cfg
    active_start = cfg.epochs_before_active
    active_end = active_start + cfg.num_epochs_active
    if not (active_start <= epoch < active_end):
        # Outside the active window: uniform sampling without replacement
        # over the whole population, not just the remaining pool.
        return torch.multinomial(
            torch.ones(num_total_users, dtype=torch.float),
            users_per_round,
            replacement=False,
            generator=self.rng,
        ).tolist()

    # NOTE(review): in non-replacement mode the pool is only refilled when
    # completely empty, so it may hold fewer than `users_per_round` users
    # here; how the helper handles that is not visible from this file.
    candidate_user_indices, _ = DiverseUserSelectorUtils.select_diverse_cohort(
        data_provider=data_provider,
        global_model=global_model,
        users_per_round=users_per_round,
        available_users=self.available_users,
        rng=self.rng,
        num_search_samples=cfg.num_candidate_cohorts,
        maximize_metric=cfg.maximize_metric,
        loss_reduction_type=cfg.loss_reduction_type,
        client_gradient_scaling=cfg.client_gradient_scaling,
        diversity_metric_type=cfg.diversity_metric_type,
    )
    if not cfg.with_replacement:
        # A set lookup keeps this filter linear in the pool size instead of
        # O(pool * cohort). Assumes the helper returns hashable int indices
        # (consistent with this method's List[int] return) -- TODO confirm.
        selected = set(candidate_user_indices)
        self.available_users = [
            idx for idx in self.available_users if idx not in selected
        ]
    return candidate_user_indices