in flsim/active_user_selectors/simple_user_selector.py [0:0]
def get_user_indices(self, **kwargs) -> List[int]:
    """Pick the users that participate in the current round.

    Keyword Args (all required; extracted via ``unpack_required_inputs``):
        num_total_users: Size of the user population to sample from.
        users_per_round: Number of users to select.
        data_provider: Source of per-user data, queried with
            ``get_user_data(index)``.
        global_model: Model used to (re)compute per-user losses.
        epoch: Current training epoch.

    While ``epoch`` is below ``cfg.epochs_before_active``, the choice is
    delegated to ``_non_active_sampling``. From then on, users are drawn with
    probabilities derived from their cached losses. The losses are normalised
    by sample count, converted to probabilities and sampled with a uniformly
    random fraction mixed in. The cached loss and sample count of every drawn
    user are then recomputed against ``global_model`` before returning.

    Side effects:
        On the first active call (when ``self.user_losses`` has no elements),
        populates ``self.user_losses`` and ``self.user_sample_counts`` via
        ``_get_initial_losses_and_counts``. On every active call, overwrites
        the entries of the selected users.

    Returns:
        The selected user indices.
    """
    (
        num_total_users,
        users_per_round,
        data_provider,
        global_model,
        epoch,
    ) = self.unpack_required_inputs(
        [
            "num_total_users",
            "users_per_round",
            "data_provider",
            "global_model",
            "epoch",
        ],
        kwargs,
    )

    # pyre-fixme[16]: `HighLossActiveUserSelector` has no attribute `cfg`.
    cfg = self.cfg

    # Warm-up phase: loss information is not consulted yet.
    if epoch < cfg.epochs_before_active:
        return self._non_active_sampling(num_total_users, users_per_round)

    # First active call: compute a loss / sample count for every user.
    # NOTE(review): user_losses appears to be a torch tensor (it exposes
    # nelement()) — confirm against the class initialiser.
    if self.user_losses.nelement() == 0:
        initial_losses, initial_counts = self._get_initial_losses_and_counts(
            num_total_users, data_provider, global_model
        )
        self.user_losses = initial_losses
        self.user_sample_counts = initial_counts

    utilities = ActiveUserSelectorUtils.normalize_by_sample_count(
        user_utility=self.user_losses,
        user_sample_counts=self.user_sample_counts,
        averaging_exponent=cfg.count_normalization_exponent,
    )
    sampling_probs = ActiveUserSelectorUtils.convert_to_probability(
        user_utility=utilities,
        fraction_with_zero_prob=cfg.fraction_with_zero_prob,
        softmax_temperature=cfg.softmax_temperature,
    )
    chosen = ActiveUserSelectorUtils.select_users(
        users_per_round=users_per_round,
        probs=sampling_probs,
        fraction_uniformly_random=cfg.fraction_uniformly_random,
        rng=self.rng,
    )

    # Recompute the cached statistics of the chosen users with the current
    # global model, so later calls sample from fresher losses.
    for user_idx in chosen:
        loss, sample_count = self._get_user_loss_and_sample_count(
            data_provider.get_user_data(user_idx), global_model
        )
        self.user_losses[user_idx] = loss
        self.user_sample_counts[user_idx] = sample_count

    return chosen