def load_and_cache_examples()

in biolm/run_classification.py [0:0]


def load_and_cache_examples(args, task, tokenizer, evaluate=False, return_examples=False):
    """Load (or build and cache) the examples for ``task`` and return them as a ``TensorDataset``.

    Features are cached with ``torch.save`` in ``args.data_dir``. The raw
    ``InputExample`` objects are cached alongside them in a ``.examples.jsonl``
    file (one JSON object per line). The cache is reused only when both files
    exist and ``args.overwrite_cache`` is not set; otherwise everything is
    rebuilt from the processor and re-cached by the main process.

    Args:
        args: Run configuration. Reads ``local_rank``, ``data_dir``, ``do_test``,
            ``model_name_or_path``, ``max_seq_length``, ``overwrite_cache`` and
            ``model_type``.
        task: Key into ``processors`` / ``output_modes``. For tasks whose name
            starts with ``gad`` or ``euadr``, the digits in the name are parsed
            into ``processor.fold``.
        tokenizer: Passed to ``convert_examples_to_features``; its ``pad_token``
            id is used as the padding token.
        evaluate: If true, load the dev split (or the test split when
            ``args.do_test`` is set) instead of the train split.
        return_examples: If true, also return the list of ``InputExample``.

    Returns:
        ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``, or
        ``(dataset, examples)`` when ``return_examples`` is true.

    Raises:
        ValueError: if the cached features and cached examples differ in length,
            or if the task's output mode is not recognized.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training processes the
        # dataset; the others wait here and then read the cache it writes.
        torch.distributed.barrier()

    processor = processors[task]()
    if task.startswith(('gad', 'euadr')):
        # Derive the fold from ``task`` (not ``args.task_name``) so it always
        # matches the processor selected on the line above.
        fold = int(''.join([t for t in task if t.isdigit()]))
        processor.fold = fold
    output_mode = output_modes[task]

    # Load data features from cache or dataset file
    split = ("test" if args.do_test else "dev") if evaluate else "train"
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            split,
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    cached_examples_file = cached_features_file + '.examples.jsonl'

    # Both cache files are required: the examples file is written after the
    # features file, so it may be missing for older or interrupted caches.
    use_cache = (
        os.path.exists(cached_features_file)
        and os.path.exists(cached_examples_file)
        and not args.overwrite_cache
    )
    if use_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(review): torch.load unpickles arbitrary objects, so only load caches
        # this script wrote. On newer torch versions the default weights_only=True
        # may reject these feature objects — confirm against the installed torch.
        features = torch.load(cached_features_file)
        with open(cached_examples_file) as f:
            examples = [InputExample(**json.loads(line)) for line in f if line.strip()]
        if len(examples) != len(features):
            raise ValueError(
                "Cached examples ({}) and features ({}) are out of sync for {}; "
                "rerun with --overwrite_cache".format(
                    len(examples), len(features), cached_features_file
                )
            )
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        if not evaluate:
            examples = processor.get_train_examples(args.data_dir)
        elif args.do_test:
            examples = processor.get_test_examples(args.data_dir)
        else:
            examples = processor.get_dev_examples(args.data_dir)
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=bool(args.model_type in ["xlnet"]),  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
            with open(cached_examples_file, 'w') as f:
                for example in examples:
                    f.write(json.dumps(dataclasses.asdict(example), sort_keys=True) + "\n")

    if args.local_rank == 0 and not evaluate:
        # Release the waiting processes now that the cache has been written.
        torch.distributed.barrier()

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
    elif output_mode == 'multilabel_classification':
        # Multi-hot matrix: one row per feature, one column per label; each entry
        # of ``feature.label`` is used as a column index and set to 1.
        all_labels = torch.zeros((len(features), len(processor.get_labels())), dtype=torch.long)
        for row, feature in enumerate(features):
            for label_idx in feature.label:
                all_labels[row, label_idx] = 1
    else:
        raise ValueError(
            "Unsupported output_mode {!r} for task {!r}".format(output_mode, task)
        )

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    if return_examples:
        return dataset, examples
    return dataset