in training/data.py [0:0]
def __init__(self, **kwargs):
    """Build an ``EqaDataset`` from an HDF5 questions file and wrap it in this loader.

    Required keyword arguments (a missing one raises ``ValueError``):
        questions_h5: path of the HDF5 questions file, opened read-only.
        data_json: forwarded to ``EqaDataset`` as ``data_json``.
        vocab: forwarded to ``EqaDataset`` as its second positional argument.
        input_type: forwarded to ``EqaDataset``; also selects ``collate_fn``
            (``eqaCollateCnn`` if it contains ``'image'`` or ``'cnn'``,
            otherwise ``eqaCollateSeq2seq`` if it contains ``'lstm'``).
        split: forwarded to ``EqaDataset``.
        gpu_id: forwarded to ``EqaDataset``.
        num_frames: forwarded to ``EqaDataset``.

    Optional keyword arguments and their defaults:
        max_threads_per_gpu (10), to_cache (False),
        target_obj_conn_map_dir (False), map_resolution (1000),
        overfit (False), max_controller_actions (5), max_actions (None).

    All remaining keyword arguments are passed unchanged, together with the
    dataset, to the parent class ``__init__`` (presumably
    ``torch.utils.data.DataLoader`` options such as ``batch_size`` --
    confirm against the class definition).

    Raises:
        ValueError: if any required keyword argument is missing.
    """
    # Validate every required argument before doing any work, so a missing
    # one fails fast with a clear message instead of surfacing as a KeyError
    # after the HDF5 file has already been opened. Order matches the
    # original checks, with num_frames (previously unchecked) last.
    for key in ('questions_h5', 'data_json', 'vocab', 'input_type',
                'split', 'gpu_id', 'num_frames'):
        if key not in kwargs:
            raise ValueError('Must give ' + key)

    # Pop dataset-specific arguments so they are not forwarded to the parent.
    questions_h5_path = kwargs.pop('questions_h5')
    data_json = kwargs.pop('data_json')
    input_type = kwargs.pop('input_type')
    split = kwargs.pop('split')
    vocab = kwargs.pop('vocab')
    gpu_id = kwargs.pop('gpu_id')
    num_frames = kwargs.pop('num_frames')

    max_threads_per_gpu = kwargs.pop('max_threads_per_gpu', 10)
    to_cache = kwargs.pop('to_cache', False)
    # NOTE(review): the default is False rather than None; EqaDataset
    # presumably treats any falsy value as "no directory" -- confirm.
    target_obj_conn_map_dir = kwargs.pop('target_obj_conn_map_dir', False)
    map_resolution = kwargs.pop('map_resolution', 1000)
    overfit = kwargs.pop('overfit', False)
    max_controller_actions = kwargs.pop('max_controller_actions', 5)
    max_actions = kwargs.pop('max_actions', None)

    # Pick the batch collation function from the input type. For these
    # input types any caller-supplied collate_fn is overwritten; for other
    # input types a caller-supplied collate_fn (if any) is forwarded as-is.
    if 'image' in input_type or 'cnn' in input_type:
        kwargs['collate_fn'] = eqaCollateCnn
    elif 'lstm' in input_type:
        kwargs['collate_fn'] = eqaCollateSeq2seq

    print('Reading questions from ', questions_h5_path)
    # NOTE(review): the HDF5 handle is closed as soon as this block exits,
    # so this assumes EqaDataset copies what it needs during construction
    # rather than keeping lazy references into the file -- confirm.
    with h5py.File(questions_h5_path, 'r') as questions_h5:
        self.dataset = EqaDataset(
            questions_h5,
            vocab,
            num_frames=num_frames,
            data_json=data_json,
            split=split,
            gpu_id=gpu_id,
            input_type=input_type,
            max_threads_per_gpu=max_threads_per_gpu,
            to_cache=to_cache,
            target_obj_conn_map_dir=target_obj_conn_map_dir,
            map_resolution=map_resolution,
            overfit=overfit,
            max_controller_actions=max_controller_actions,
            max_actions=max_actions)
    super(EqaDataLoader, self).__init__(self.dataset, **kwargs)