in train.py [0:0]
def get_sample_for_visualization(data, preprocess_fn, num, dataset):
    """Return the first batch of ``data`` in raw and preprocessed form.

    Args:
        data: Any dataset accepted by ``DataLoader``. Only the first batch is
            drawn.
        preprocess_fn: Callable applied to the whole collated batch. Element
            ``[0]`` of its result is returned as the preprocessed sample.
        num: Batch size used to draw the sample. The returned batch may be
            smaller if ``data`` holds fewer than ``num`` items.
        dataset: Dataset name. For ``'ffhq_1024'``, the raw images
            (``batch[0]``) are scaled by 255, rounded, clamped, cast to uint8
            and permuted with ``(0, 2, 3, 1)``. For any other name,
            ``batch[0]`` is returned unchanged.

    Returns:
        ``(orig_image, preprocessed)``.

    Raises:
        ValueError: If ``data`` yields no batches.
    """
    try:
        batch = next(iter(DataLoader(data, batch_size=num)))
    except StopIteration:
        # Without this, an empty dataset surfaces as an obscure unbound-name error.
        raise ValueError(
            "cannot draw a visualization sample from an empty dataset"
        ) from None

    if dataset == 'ffhq_1024':
        # NOTE(review): assumes batch[0] holds float images in [0, 1] laid out
        # as (N, C, H, W); the permute turns that into (N, H, W, C) -- confirm
        # against the ffhq_1024 loader.
        # Round (+0.5) and clamp before the cast: a bare uint8 cast truncates,
        # so float error like 127.99998 would become 127 instead of 128, and
        # out-of-range values have no well-defined uint8 result.
        orig_image = (
            batch[0]
            .mul(255.0)
            .add(0.5)
            .clamp(0, 255)
            .to(torch.uint8)
            .permute(0, 2, 3, 1)
        )
    else:
        orig_image = batch[0]
    preprocessed = preprocess_fn(batch)[0]
    return orig_image, preprocessed