in vilbert/task_utils.py
def LoadDatasetEval(args, task_cfg, ids):
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

    task_feature_reader1 = {}
    task_feature_reader2 = {}
    # Collect the distinct feature-file paths so each H5 file is opened only
    # once, even when several tasks share the same features.
    for i, task_id in enumerate(ids):
        task = "TASK" + task_id
        if task_cfg[task]["features_h5path1"] not in task_feature_reader1:
            task_feature_reader1[task_cfg[task]["features_h5path1"]] = None
        if task_cfg[task]["features_h5path2"] not in task_feature_reader2:
            task_feature_reader2[task_cfg[task]["features_h5path2"]] = None

    # Initialize the feature readers.
    for features_h5path in task_feature_reader1.keys():
        if features_h5path != "":
            task_feature_reader1[features_h5path] = ImageFeaturesH5Reader(
                features_h5path, args.in_memory
            )
    for features_h5path in task_feature_reader2.keys():
        if features_h5path != "":
            task_feature_reader2[features_h5path] = ImageFeaturesH5Reader(
                features_h5path, args.in_memory
            )

    task_datasets_val = {}
    task_dataloader_val = {}
    task_ids = []
    task_batch_size = {}
    task_num_iters = {}

    for i, task_id in enumerate(ids):
        task = "TASK" + task_id
        task_ids.append(task)
        task_name = task_cfg[task]["name"]

        batch_size = args.batch_size
        if args.local_rank != -1:
            # Under distributed evaluation, split the batch across processes.
            batch_size = int(batch_size / dist.get_world_size())

        num_workers = int(args.num_workers / len(ids))
        logger.info(
            "Loading %s Dataset with batch size %d"
            % (task_cfg[task]["name"], batch_size)
        )

        # Prefer the split passed on the command line; otherwise fall back to
        # the validation split from the task config.
        if args.split:
            eval_split = args.split
        else:
            eval_split = task_cfg[task]["val_split"]

        task_datasets_val[task] = DatasetMapEval[task_name](
            task=task_cfg[task]["name"],
            dataroot=task_cfg[task]["dataroot"],
            annotations_jsonpath=task_cfg[task]["val_annotations_jsonpath"],
            split=eval_split,
            image_features_reader=task_feature_reader1[
                task_cfg[task]["features_h5path1"]
            ],
            gt_image_features_reader=task_feature_reader2[
                task_cfg[task]["features_h5path2"]
            ],
            tokenizer=tokenizer,
            bert_model=args.bert_model,
            clean_datasets=args.clean_train_sets,
            padding_index=0,
            max_seq_length=task_cfg[task]["max_seq_length"],
            max_region_num=task_cfg[task]["max_region_num"],
        )

        # Note: the loader uses a fixed worker count here, so the per-task
        # `num_workers` computed above is not applied.
        task_dataloader_val[task] = DataLoader(
            task_datasets_val[task],
            shuffle=False,
            batch_size=batch_size,
            num_workers=10,
            pin_memory=True,
        )

        task_num_iters[task] = len(task_dataloader_val[task])
        task_batch_size[task] = batch_size

    return (
        task_batch_size,
        task_num_iters,
        task_ids,
        task_datasets_val,
        task_dataloader_val,
    )
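
For context, a minimal sketch of how an evaluation script might consume this function. The config path, the fields set on `args`, and the "1-15"-style task string are illustrative assumptions, not the repository's actual eval script; `LoadDatasetEval` only requires that `args` exposes the attributes it reads and that `task_cfg` supports `task_cfg["TASK<id>"][...]` indexing.

# Hypothetical caller sketch (assumed names; not the repository's eval script).
import yaml
from argparse import Namespace
from vilbert.task_utils import LoadDatasetEval

args = Namespace(
    bert_model="bert-base-uncased",
    batch_size=256,
    local_rank=-1,          # -1 = single-process evaluation
    num_workers=16,
    in_memory=False,
    split="",               # empty -> fall back to each task's val_split
    clean_train_sets=True,
    tasks="1",              # task numbers, e.g. "1" or "1-15"
)

with open("vilbert_tasks.yml", "r") as f:   # config path is an assumption
    task_cfg = yaml.safe_load(f)

(
    task_batch_size,
    task_num_iters,
    task_ids,
    task_datasets_val,
    task_dataloader_val,
) = LoadDatasetEval(args, task_cfg, args.tasks.split("-"))

for task in task_ids:
    for batch in task_dataloader_val[task]:
        pass  # run the model on `batch` and accumulate per-task metrics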