def get_data_loaders()

in src/datatuner/lm/data_loader.py [0:0]


def _last_learnt_text_field(task_config):
    """Return the ``id`` of the last entry of ``task_config["data_shape"]`` that is learnt and of type text.

    Raises:
        ValueError: if no entry has a truthy ``learn`` and ``type == "text"``.
    """
    for field in reversed(task_config["data_shape"]):
        if field["learn"] and field["type"] == "text":
            return field["id"]
    raise ValueError('task_config["data_shape"] contains no learnt field of type "text"')


def _build_split_features(dataset, split, tokenizer, task_config, args, last_learnt_field):
    """Build per-instance model inputs for every data point of one split.

    Each data point contributes ``num_candidates`` instances, taken from the trailing values of
    ``last_learnt_field``. The last candidate is the ground truth. Earlier candidates are
    distractors whose LM labels are masked.

    Note: string values of ``last_learnt_field`` are wrapped into a one-element list in place,
    so the raw data points are mutated (same as before the refactor).

    Returns:
        A ``defaultdict(list)`` mapping input name -> list of per-instance values. It also holds
        ``"n_candidates"`` (an int) and, when ``args.multitask`` is set, ``"mc_labels"``.

    Raises:
        ValueError: if multitask data does not store candidates as a list, or a data point has
            fewer candidate values than required.
    """
    features = defaultdict(list)

    if args.multitask:
        first_value = dataset[0][last_learnt_field]
        if not isinstance(first_value, list):
            raise ValueError(
                f"multitask mode expects '{last_learnt_field}' to be a list of candidates in {split} data, "
                f"got {type(first_value).__name__}"
            )
        num_candidates = len(first_value)
    else:
        num_candidates = 1

    if args.num_candidates > 0:
        num_candidates = min(args.num_candidates, num_candidates)

    # Stored alongside the inputs because the tensor conversion step reads it from this dict.
    features["n_candidates"] = num_candidates

    for index, data_point in enumerate(dataset):
        if isinstance(data_point[last_learnt_field], str):
            data_point[last_learnt_field] = [data_point[last_learnt_field]]

        candidates = data_point[last_learnt_field][-num_candidates:]
        # A short candidate list would mask the ground truth and misalign the per-example grouping.
        if len(candidates) != num_candidates:
            raise ValueError(
                f"{split} data point {index} has {len(candidates)} value(s) for '{last_learnt_field}' "
                f"but {num_candidates} candidate(s) are required"
            )

        for j, candidate_val in enumerate(candidates):
            # the last item in the array is the ground truth. For other distractor items, we mask the LM labels
            instance, _ = build_input_from_segments(
                data_point,
                tokenizer,
                task_config,
                mask_lm_labels=(j != num_candidates - 1),
                last_learnt_field=last_learnt_field,
                candidate_val=candidate_val,
                max_block_size=args.max_block_size,
            )
            if args.multitask:
                # this is an indicator for the last input token, used in the Double Head model
                instance["mc_token_ids"] = len(instance["input_ids"]) - 1

            for input_name, input_array in instance.items():
                features[input_name].append(input_array)

        # the ground truth is the last item in the array; previous items are distractors
        if args.multitask:
            features["mc_labels"].append(num_candidates - 1)

    return features


def _split_to_tensors(features, split, pad_id):
    """Pad one split's features and convert each entry of ``MODEL_INPUTS`` present into a tensor.

    Every input except ``mc_labels`` is reshaped from flat per-candidate rows to
    ``(-1, n_candidates) + remaining_dims`` so candidates are grouped per data point.

    Raises:
        ValueError: if an input's row count is not divisible into groups of ``n_candidates``.
    """
    num_candidates = features["n_candidates"]
    padded = pad_dataset(features, padding=pad_id)

    tensors = []
    for input_name in MODEL_INPUTS:
        if input_name not in padded:
            continue
        tensor = torch.tensor(padded[input_name])
        if input_name != "mc_labels":
            # adjust the shape as we might have more than one candidate in the case of DoubleHeads
            try:
                tensor = tensor.view((-1, num_candidates) + tensor.shape[1:])
            except RuntimeError as err:
                raise ValueError(
                    f"Cannot reshape {split} input '{input_name}' of shape {tuple(tensor.shape)} "
                    f"into groups of {num_candidates} candidate(s)"
                ) from err
        tensors.append(tensor)
    return tensors


def get_data_loaders(args, task_config, tokenizer):
    """Prepare the dataset for training and evaluation.

    Loads the raw "validation" and "train" splits via ``get_dataset`` and builds model inputs
    with ``build_input_from_segments``. It then pads the inputs and converts them to tensors,
    and wraps them in ``TensorDataset``/``DataLoader`` objects.

    In distributed runs, ranks other than -1/0 wait at a barrier until rank 0 has loaded the data.
    They then load with ``args.ignore_cache`` forced to False. Note this mutates ``args``.

    Args:
        args: run configuration. Attributes read: local_rank, ignore_cache, dataset_cache,
            dataset_path, max_data, val_max_data, max_block_size, multitask, num_candidates,
            distributed, train_batch_size, valid_batch_size.
        task_config: task description; its "data_shape" entries need "id", "learn" and "type".
        tokenizer: tokenizer providing ``convert_tokens_to_ids``.

    Returns:
        ``(train_loader, valid_loader, train_sampler, valid_sampler)``. The samplers are
        ``DistributedSampler`` instances when ``args.distributed`` is set, otherwise None.
    """
    logger.info("Loading training data")

    if args.local_rank not in [-1, 0]:
        # Non-primary ranks wait until rank 0 has built the dataset cache, then reuse that cache.
        torch.distributed.barrier()
        args.ignore_cache = False

    datasets_raw = {}
    for split in ["validation", "train"]:
        logger.info(f"Loading {split} data")
        datasets_raw[split] = get_dataset(
            tokenizer,
            args.dataset_cache,
            task_config,
            args.dataset_path,
            split,
            args.max_data if split == "train" else args.val_max_data,
            args.ignore_cache,
            args.max_block_size,
        )

    logger.info("Build inputs and labels")
    last_learnt_field = _last_learnt_text_field(task_config)
    datasets = {
        split: _build_split_features(dataset, split, tokenizer, task_config, args, last_learnt_field)
        for split, dataset in datasets_raw.items()
    }

    logger.info("Pad inputs and convert to Tensor")
    pad_id = tokenizer.convert_tokens_to_ids(PAD_TOKEN)
    tensor_datasets = {split: _split_to_tensors(datasets[split], split, pad_id) for split in ("train", "validation")}

    logger.info("Build train and validation dataloaders")
    train_dataset = TensorDataset(*tensor_datasets["train"])
    valid_dataset = TensorDataset(*tensor_datasets["validation"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
    # A sampler and shuffle=True are mutually exclusive, so only shuffle when not distributed.
    train_loader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed)
    )
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)

    logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
    logger.info("validation dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))

    if args.local_rank == 0:
        # Rank 0 has finished loading/caching; release the ranks waiting at the barrier above.
        torch.distributed.barrier()

    return train_loader, valid_loader, train_sampler, valid_sampler