# src/datatuner/lm/data_loader.py
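# NOTE: imports sketched here so the excerpt reads standalone; the surrounding module
# is assumed to also provide `logger`, `get_dataset`, `build_input_from_segments`,
# `pad_dataset`, `PAD_TOKEN`, and `MODEL_INPUTS`.
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, TensorDataset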
def get_data_loaders(args, task_config, tokenizer):
""" Prepare the dataset for training and evaluation """
datasets_raw = {}
logger.info("Loading training data")
if args.local_rank not in [-1, 0]:
# in distributed training, non-master processes wait here while rank 0 processes and caches the dataset
torch.distributed.barrier()
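# non-master ranks must not rebuild the dataset cache; force them to reuse the cache written by rank 0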
args.ignore_cache = False
for split in ["validation", "train"]:
logger.info(f"Loading {split} data")
datasets_raw[split] = get_dataset(
tokenizer,
args.dataset_cache,
task_config,
args.dataset_path,
split,
args.max_data if split == "train" else args.val_max_data,
args.ignore_cache,
args.max_block_size,
)
logger.info("Build inputs and labels")
datasets = {"train": defaultdict(list), "validation": defaultdict(list)}
for dataset_name, dataset in datasets_raw.items():
# find the last text field the model learns to generate (scanning data_shape in reverse); candidates attach to this field
last_learnt_field = [x["id"] for x in task_config["data_shape"][::-1] if x["learn"] and x["type"] == "text"][0]
if args.multitask:
assert isinstance(dataset[0][last_learnt_field], list)
num_candidates = len(dataset[0][last_learnt_field])
else:
num_candidates = 1
if args.num_candidates > 0 and dataset_name in ["train", "validation"]:
num_candidates = min(args.num_candidates, num_candidates)
for data_point in dataset:
if isinstance(data_point[last_learnt_field], str):
data_point[last_learnt_field] = [data_point[last_learnt_field]]
for j, candidate_val in enumerate(data_point[last_learnt_field][-num_candidates:]):
# the last candidate is the ground truth; for the distractor items we mask the LM labels
mask_lm_labels = j != num_candidates - 1
instance, _ = build_input_from_segments(
data_point,
tokenizer,
task_config,
mask_lm_labels=mask_lm_labels,
last_learnt_field=last_learnt_field,
candidate_val=candidate_val,
max_block_size=args.max_block_size,
)
if args.multitask:
# this is an indicator for the last input token, used in the Double Head model
instance["mc_token_ids"] = len(instance["input_ids"]) - 1
for input_name, input_array in instance.items():
datasets[dataset_name][input_name].append(input_array)
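# remember how many candidates each example carries so the flat lists can be reshaped to (batch, n_candidates, ...) below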
datasets[dataset_name]["n_candidates"] = num_candidates
# the ground truth is the last item in the array; previous items are distractors
if args.multitask:
datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
logger.info("Pad inputs and convert to Tensor")
tensor_datasets = {"train": [], "validation": []}
for dataset_name, dataset in datasets.items():
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(PAD_TOKEN))
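# pad_dataset brings every sequence in the split to a common length using the PAD token id, so plain tensors can be built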
for input_name in MODEL_INPUTS:
if input_name in dataset:
tensor = torch.tensor(dataset[input_name])
if input_name != "mc_labels":
# group the candidates per example: (batch, n_candidates, ...), as expected by the DoubleHeads model
tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
tensor_datasets[dataset_name].append(tensor)
logger.info("Build train and validation dataloaders")
train_dataset, valid_dataset = (
TensorDataset(*tensor_datasets["train"]),
TensorDataset(*tensor_datasets["validation"]),
)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
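# in distributed training, a DistributedSampler shards each dataset across processes, so shuffling is left to the sampler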
train_loader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed)
)
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)
logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
logger.info("validation dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
if args.local_rank == 0:
# rank 0 has finished building the dataset cache; release the other processes waiting at the first barrier
torch.distributed.barrier()
return train_loader, valid_loader, train_sampler, valid_sampler
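
# Usage sketch (illustrative only, not part of the module): `args`, `task_config`, and
# `tokenizer` are built by the training script; `args.n_epochs` is an assumed field.
#
#     train_loader, valid_loader, train_sampler, _ = get_data_loaders(args, task_config, tokenizer)
#     for epoch in range(args.n_epochs):
#         if args.distributed:
#             train_sampler.set_epoch(epoch)  # reshuffle the distributed shards each epoch
#         for batch in train_loader:
#             # tensors follow the MODEL_INPUTS ordering, each shaped (batch, n_candidates, ...)
#             ...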