def fl_train_set()

in flsim/data/dataset_data_loader.py [0:0]


    def fl_train_set(self, **kwargs) -> Iterable[Iterable[Any]]:
        """Return this worker's share of the training data, batched per user.

        Rows are grouped by user via ``self.sharder.shard_rows(self.train_dataset)``,
        which yields ``(_, user_data)`` pairs. Each user's rows are split into
        consecutive chunks of at most ``self.train_batch_size`` rows. Each chunk
        is collated into a dict mapping every attribute name to ``torch.stack``
        of that attribute across the chunk's rows.

        Users are distributed across workers by index: user ``i`` is kept only
        when ``i % world_size == rank``.

        Keyword Args:
            rank: Index of this worker (default ``0``).
            world_size: Total number of workers (default ``1``).

        Returns:
            A list with one entry per user assigned to this worker. Each entry
            is a list of collated batch dicts; it is empty for a user with no
            rows. Returns ``[]`` when the sharder yields no users.

        Side effects:
            Sets ``self._num_total_users`` to the number of users across ALL
            workers, not only those returned.
        """
        rank = kwargs.get("rank", 0)
        world_size = kwargs.get("world_size", 1)
        batch_size = self.train_batch_size

        train_batches = [
            user_data for _, user_data in self.sharder.shard_rows(self.train_dataset)
        ]

        # Attribute names come from the first row of the first non-empty user.
        # Every row must provide at least these keys; extra keys are ignored.
        # Looking past empty users means an empty dataset, or an empty first
        # user, no longer raises IndexError.
        keys = next(
            (list(row.keys()) for user_data in train_batches for row in user_data),
            [],
        )

        final_train_batches = []
        for user_idx, user_data in enumerate(train_batches):
            # Split this user's rows into consecutive chunks of batch_size rows.
            row_chunks = []
            for i, row in enumerate(user_data):
                if i % batch_size == 0:
                    row_chunks.append([])
                row_chunks[-1].append(row)

            # Collate each chunk into one stacked tensor per attribute.
            user_batches = [
                {key: torch.stack([row[key] for row in chunk]) for key in keys}
                for chunk in row_chunks
            ]

            # Assign users to workers round-robin by their running index.
            # This matches the previous counter-based check, assuming the
            # `num_total_users` property simply reports `_num_total_users`
            # (TODO confirm), without depending on that property mid-loop.
            if user_idx % world_size == rank:
                final_train_batches.append(user_batches)

        # Total users across all workers, not just this rank's share.
        self._num_total_users = len(train_batches)
        return final_train_batches