in flsim/data/dataset_data_loader.py [0:0]
def fl_train_set(self, **kwargs) -> Iterable[Iterable[Any]]:
    self._num_total_users = 0
    rank = kwargs.get("rank", 0)
    world_size = kwargs.get("world_size", 1)
    train_batches = [
        user_data for _, user_data in self.sharder.shard_rows(self.train_dataset)
    ]
    # batch each user's rows collected above
    final_train_batches = []
    # fetch the attribute names shared by every row
    keys = list(train_batches[0][0].keys())
    for one_user_data in train_batches:
        batched_user_data = []
        for i, single_data in enumerate(one_user_data):
            if i % self.train_batch_size == 0:
                batched_user_data.append([])
            batched_user_data[-1].append(single_data)
        new_batched_user_data = []
        for a_batched_user_data in batched_user_data:
            batched_data_rows = {}
            for key in keys:
                batched_data_rows[key] = []
            for single_user_data in a_batched_user_data:
                for key in keys:
                    batched_data_rows[key].append(single_user_data[key])
            for key in keys:
                batched_data_rows[key] = torch.stack(batched_data_rows[key])
            new_batched_user_data.append(batched_data_rows)
        # assign users round-robin: this worker keeps users whose index % world_size == rank
        if self.num_total_users % world_size == rank:
            final_train_batches.append(new_batched_user_data)
        # count every user, including those assigned to other workers
        self._num_total_users += 1
    return final_train_batches
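
For reference, a minimal self-contained sketch of the per-user batching that the inner loops above perform; the `batch_user_rows` helper, the toy rows, and the batch size of 2 are illustrative assumptions, not part of the FLSim API:

import torch

def batch_user_rows(user_rows, batch_size):
    # Group one user's dict-rows into fixed-size batches and stack each field,
    # mirroring the inner loops of fl_train_set (illustrative helper, not FLSim code).
    keys = list(user_rows[0].keys())
    chunks = [user_rows[i : i + batch_size] for i in range(0, len(user_rows), batch_size)]
    return [
        {key: torch.stack([row[key] for row in rows]) for key in keys}
        for rows in chunks
    ]

# Toy user with 5 rows; batch_size=2 yields batches of 2, 2, and 1 rows.
user_rows = [{"features": torch.randn(3), "labels": torch.tensor(0)} for _ in range(5)]
for batch in batch_user_rows(user_rows, batch_size=2):
    print(batch["features"].shape, batch["labels"].shape)
# -> torch.Size([2, 3]) torch.Size([2]) twice, then torch.Size([1, 3]) torch.Size([1])

Each element of final_train_batches is then the list of such per-batch dicts for one user, restricted to the users this worker owns under the rank/world_size round-robin split.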