# in flsim/utils/example_utils.py [0:0]
def build_data_provider(local_batch_size, examples_per_user, image_size):
    """Build a federated data provider over the CIFAR-10 dataset.

    Args:
        local_batch_size: Batch size handed to the federated ``DataLoader``.
        examples_per_user: Forwarded to ``SequentialSharder`` as
            ``examples_per_shard``.
        image_size: Size given to both ``Resize`` and ``CenterCrop``.

    Returns:
        The ``DataProvider`` wrapping the sharded, batched data loader.

    Note:
        Both CIFAR-10 splits are requested with ``download=True`` under
        ``./cifar10``. The ``train=False`` split is used as both the eval
        and the test dataset. The total number of clients is printed
        before returning.
    """
    # Step 1: preprocessing pipeline shared by both splits.
    # Values passed to Normalize as its (mean, std) arguments.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocessing_steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    preprocess = transforms.Compose(preprocessing_steps)

    # Step 2: the two CIFAR-10 splits differ only in the `train` flag.
    cifar_kwargs = {
        "root": "./cifar10",
        "download": True,
        "transform": preprocess,
    }
    training_split = CIFAR10(train=True, **cifar_kwargs)
    held_out_split = CIFAR10(train=False, **cifar_kwargs)

    # Step 3: the sharder maps training samples to clients.
    client_sharder = SequentialSharder(examples_per_shard=examples_per_user)

    # Step 4: shard and batch; the held-out split doubles as eval and test.
    federated_loader = DataLoader(
        train_dataset=training_split,
        eval_dataset=held_out_split,
        test_dataset=held_out_split,
        sharder=client_sharder,
        batch_size=local_batch_size,
        drop_last=False,
    )

    # Step 5: expose the loader through a data provider and report its size.
    provider = DataProvider(federated_loader)
    print(f"Clients in total: {provider.num_users()}")
    return provider