in datasets.py [0:0]
def get_dataset(config):
"""Returns train/val dataloaders."""
config['seq_len'] = config['n_ctx'] + config['n_steps']
normalize = True
if config['dataset'] == 'stochastic':
from dataloaders.stochastic_mmnist import MovingMNIST
train_dataset = MovingMNIST(
True,
seq_len=config['seq_len'],
deterministic=False
)
val_dataset = MovingMNIST(
False,
seq_len=config['seq_len'],
deterministic=False
)
img_ch = 1
elif config['dataset'] == 'pushbair':
from dataloaders.pushbair_loader import PushDataset
train_dataset = PushDataset(
'train',
config['seq_len'],
normalize=normalize,
)
val_dataset = PushDataset(
'test',
config['seq_len'],
normalize=normalize,
)
img_ch = 3
elif config['dataset'] == 'pushbair_fvd':
from dataloaders.pushbair_fvd_loader import PushDataset
train_dataset = PushDataset(
'train',
config['seq_len'],
normalize=normalize,
)
val_dataset = PushDataset(
'test',
config['seq_len'],
normalize=normalize,
)
img_ch = 3
elif config['dataset'] == 'cityscapes':
from dataloaders.cityscapes_loader import CityscapesDataset
train_dataset = CityscapesDataset(
'train_64',
config['seq_len'],
img_side=64,
normalize=normalize,
resize=False,
)
val_dataset = CityscapesDataset(
'test_64',
config['seq_len'],
img_side=64,
normalize=normalize,
resize=False,
)
img_ch = 3
elif config['dataset'] == 'cityscapes128':
from dataloaders.cityscapes_loader import CityscapesDataset
train_dataset = CityscapesDataset(
'train_128',
config['seq_len'],
img_side=128,
normalize=normalize,
resize=False,
)
val_dataset = CityscapesDataset(
'test_128',
config['seq_len'],
img_side=128,
normalize=normalize,
resize=False,
)
img_ch = 3
def init_fun(worker_id):
return np.random.seed()
if config['multigpu']:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
sampler=train_sampler,
# shuffle=True,
num_workers=config['n_workers'],
worker_init_fn=init_fun
)
val_loader = DataLoader(
val_dataset,
batch_size=config['batch_size'],
sampler=val_sampler,
# shuffle=False,
num_workers=config['n_workers'],
worker_init_fn=init_fun
)
else:
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=True,
num_workers=config['n_workers'],
worker_init_fn=init_fun,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=True,
# shuffle=False,
num_workers=config['n_workers'],
worker_init_fn=init_fun,
)
config['img_ch'] = img_ch
config['batches_per_epoch'] = len(train_dataset)//config['batch_size']
return train_loader, val_loader