in utils/main_utils.py [0:0]
def build_dataloader(db_cfg, split_cfg, num_workers, distributed):
    """Build a ``DataLoader`` of audio-visual clips for one dataset split.

    Args:
        db_cfg: Mapping with dataset-level settings. Keys read here:
            ``'name'`` (``'audioset'`` or ``'kinetics'``),
            ``'transforms'`` (``'crop+color'`` or ``'msc+color'``),
            ``'video_clip_duration'``, ``'video_fps'``, ``'crop_size'``,
            ``'frame_size'`` (only for ``'crop+color'``),
            ``'audio_clip_duration'``, ``'audio_fps'``, ``'n_fft'``,
            ``'spectrogram_fps'`` and ``'batch_size'``.
        split_cfg: Mapping with split-level settings. Keys read here:
            ``'split'``, ``'use_augmentation'``, ``'drop_last'`` and the
            optional ``'clips_per_video'`` (defaults to 1 when absent).
        num_workers: Forwarded to ``DataLoader(num_workers=...)``.
        distributed: When truthy, samples are drawn through a
            ``DistributedSampler``; otherwise the loader shuffles itself.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the selected dataset.
        The dataset is built with ``return_labels=False`` and
        ``return_index=True``.

    Raises:
        ValueError: If ``db_cfg['transforms']`` or ``db_cfg['name']`` is not
            one of the supported values.
    """
    # Imports stay function-local, as in the original implementation.
    import torch.utils.data
    import torch.utils.data.distributed
    from datasets import preprocessing
    import datasets

    # Video transforms
    # Frames per clip = duration * fps, truncated towards zero
    # (duration is presumably in seconds -- TODO confirm).
    num_frames = int(db_cfg['video_clip_duration'] * db_cfg['video_fps'])
    transform_name = db_cfg['transforms']
    if transform_name == 'crop+color':
        video_transform = preprocessing.VideoPrep_Crop_CJ(
            resize=db_cfg['frame_size'],
            crop=(db_cfg['crop_size'], db_cfg['crop_size']),
            augment=split_cfg['use_augmentation'],
            num_frames=num_frames,
            pad_missing=True,
        )
    elif transform_name == 'msc+color':
        video_transform = preprocessing.VideoPrep_MSC_CJ(
            crop=(db_cfg['crop_size'], db_cfg['crop_size']),
            augment=split_cfg['use_augmentation'],
            num_frames=num_frames,
            pad_missing=True,
        )
    else:
        raise ValueError('Unknown transform: {!r}'.format(transform_name))

    # Audio transforms
    audio_transforms = [
        preprocessing.AudioPrep(
            trim_pad=True,
            duration=db_cfg['audio_clip_duration'],
            augment=split_cfg['use_augmentation'],
            missing_as_zero=True),
        preprocessing.LogSpectrogram(
            db_cfg['audio_fps'],
            n_fft=db_cfg['n_fft'],
            # Hop size is the reciprocal of the requested spectrogram rate.
            hop_size=1. / db_cfg['spectrogram_fps'],
            normalize=True),
    ]
    audio_fps_out = db_cfg['spectrogram_fps']

    # Dataset class selection
    dataset_name = db_cfg['name']
    if dataset_name == 'audioset':
        dataset = datasets.AudioSet
    elif dataset_name == 'kinetics':
        dataset = datasets.Kinetics
    else:
        raise ValueError('Unknown dataset: {!r}'.format(dataset_name))

    # 'clips_per_video' is optional in the split config.
    clips_per_video = split_cfg.get('clips_per_video', 1)

    db = dataset(
        subset=split_cfg['split'],
        return_video=True,
        video_clip_duration=db_cfg['video_clip_duration'],
        video_fps=db_cfg['video_fps'],
        video_transform=video_transform,
        return_audio=True,
        audio_clip_duration=db_cfg['audio_clip_duration'],
        audio_fps=db_cfg['audio_fps'],
        audio_fps_out=audio_fps_out,
        audio_transform=audio_transforms,
        # Audio/video off-sync augmentation is enabled only together with
        # the other augmentations.
        max_offsync_augm=0.5 if split_cfg['use_augmentation'] else 0,
        return_labels=False,
        return_index=True,
        mode='clip',
        clips_per_video=clips_per_video,
    )

    # NOTE(review): per the PyTorch docs, a DistributedSampler reshuffles
    # across epochs only if the caller invokes sampler.set_epoch(epoch);
    # this function does not do that itself.
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(db)
    else:
        sampler = None

    loader = torch.utils.data.DataLoader(
        db,
        batch_size=db_cfg['batch_size'],
        # DataLoader rejects shuffle=True together with an explicit sampler,
        # so shuffle only when no sampler is in use.
        shuffle=(sampler is None),
        drop_last=split_cfg['drop_last'],
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler)
    return loader