in flsim/active_user_selectors/diverse_user_selector.py [0:0]
def get_user_indices(self, **kwargs) -> List[int]:
    """Return a uniformly random cohort; while active, also score candidate cohorts.

    The returned cohort is always a uniformly random subset (multinomial with
    equal weights, no replacement) of ``range(num_total_users)``. The diversity
    computation performed during the active window never changes which users
    are returned.

    Keyword Args:
        num_total_users: Size of the user population to sample from.
        users_per_round: Number of users to return.
        data_provider: Forwarded to
            ``DiverseUserSelectorUtils.calculate_diversity_metrics``.
        global_model: Forwarded to
            ``DiverseUserSelectorUtils.calculate_diversity_metrics``.
        epoch: Current epoch. Candidate cohorts are scored only when
            ``cfg.epochs_before_active <= epoch <
            cfg.epochs_before_active + cfg.num_epochs_active``.

    Returns:
        ``users_per_round`` distinct user indices.

    Side effects:
        Consumes draws from ``self.rng`` (the returned cohort first, then any
        candidate cohorts). While active, replaces ``self.sample_cohorts`` with
        ``cfg.num_candidate_cohorts`` fresh random cohorts, unless
        ``cfg.constant_cohorts`` is set and cohorts already exist. Any side
        effects of ``calculate_diversity_metrics`` are not visible here.
    """
    needed = [
        "num_total_users",
        "users_per_round",
        "data_provider",
        "global_model",
        "epoch",
    ]
    (
        num_total_users,
        users_per_round,
        data_provider,
        global_model,
        epoch,
    ) = self.unpack_required_inputs(needed, kwargs)

    def draw_cohort() -> List[int]:
        # Equal weights without replacement gives a uniformly random subset.
        return torch.multinomial(
            torch.ones(num_total_users, dtype=torch.float),
            users_per_round,
            replacement=False,
            generator=self.rng,
        ).tolist()

    # Draw the returned cohort before any candidates so the RNG stream order
    # is: chosen cohort, then candidate cohorts.
    chosen = draw_cohort()

    # pyre-fixme[16]: `DiversityReportingUserSelector` has no attribute `cfg`.
    cfg = self.cfg
    # The chained comparison short-circuits, so the window end is only
    # computed once the start bound has been passed.
    if not (
        cfg.epochs_before_active
        <= epoch
        < cfg.epochs_before_active + cfg.num_epochs_active
    ):
        return chosen

    # Keep the previous candidates only when they exist AND are configured
    # to stay fixed across rounds.
    if not self.sample_cohorts or not cfg.constant_cohorts:
        self.sample_cohorts = [
            draw_cohort() for _ in range(cfg.num_candidate_cohorts)
        ]

    # NOTE(review): these scores are neither returned nor stored by this
    # method. Any reporting must happen inside calculate_diversity_metrics,
    # or it is missing. Confirm the intended behavior.
    _diversity_scores = [
        DiverseUserSelectorUtils.calculate_diversity_metrics(
            data_provider=data_provider,
            global_model=global_model,
            user_indices=cohort,
            loss_reduction_type=cfg.loss_reduction_type,
            client_gradient_scaling=cfg.client_gradient_scaling,
            diversity_metric_type=cfg.diversity_metric_type,
        )
        for cohort in self.sample_cohorts
    ]
    return chosen