in jukebox/utils/io.py
def test_dataset_loader():
    import torch as t
    import jukebox.utils.dist_adapter as dist  # thin wrapper around torch.distributed
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset
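    # Shrink the "teeny" defaults to a small stereo, unlabelled config so the test runs quickly.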
    hps = setup_hparams("teeny", {})
    hps.sr = 22050 # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2 # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)
    dataset = hps.dataset
    root = hps.root
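    # Log audio before and after pre/post-processing to TensorBoard so it can be checked by ear.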
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')
    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))
    # Torch Loader
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset, batch_size=hps.bs, num_workers=hps.nworkers,
                              pin_memory=False, sampler=sampler, drop_last=True,
                              collate_fn=collate_fn)
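    # Sync all ranks before iterating and fix the sampler's shuffle order for epoch 0.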
    dist.barrier()
    sampler.set_epoch(0)
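    # Single-batch smoke test: log raw audio, round-trip it through pre/post-processing, log the result, then stop.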
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i*hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i*hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break
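
# A minimal way to exercise the test above (a sketch, not part of the original file: it assumes
# jukebox's setup_dist_from_mpi helper, a valid hps.root/hps.dataset on disk, and one process
# per GPU launched via MPI, e.g. `mpiexec -n 1 python jukebox/utils/io.py`).
if __name__ == '__main__':
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi()  # initializes torch.distributed, needed by DistributedSampler and barrier()
    test_dataset_loader()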