def inference()

in src/nli/inference_debug.py

Runs one premise/hypothesis pair through a fine-tuned NLI sequence classifier and returns the prediction record for that single example.

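# Module-level imports assumed from src/nli/inference_debug.py: argparse,
# uuid, torch, config, plus MODEL_CLASSES, NLITransform, the Flint field
# classes, build_eval_dataset_loader_and_sampler, and eval_model.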
def inference(model_class_name, model_checkpoint_path, max_length, premise, hypothesis, cpu=True):
    # Build an empty args namespace for the eval helpers; constructing it
    # directly avoids argparse consuming the caller's command line.
    args = argparse.Namespace()

    # global_rank == -1 selects the single-process CPU path; rank 0 is used
    # for single-GPU evaluation.
    if cpu:
        args.global_rank = -1
    else:
        args.global_rank = 0

    args.model_class_name = model_class_name

    # NLI is a 3-way classification task (entailment / neutral / contradiction);
    # change num_labels for tasks with a different label set.
    num_labels = 3

    model_class_item = MODEL_CLASSES[model_class_name]
    model_name = model_class_item['model_name']
    do_lower_case = model_class_item.get('do_lower_case', False)

    tokenizer = model_class_item['tokenizer'].from_pretrained(model_name,
                                                              cache_dir=str(config.PRO_ROOT / "trans_cache"),
                                                              do_lower_case=do_lower_case)

    model = model_class_item['sequence_classification'].from_pretrained(model_name,
                                                                        cache_dir=str(config.PRO_ROOT / "trans_cache"),
                                                                        num_labels=num_labels)

    # map_location keeps a GPU-trained checkpoint loadable on a CPU-only host.
    model.load_state_dict(torch.load(model_checkpoint_path,
                                     map_location='cpu' if cpu else None))

    padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
    padding_segement_value = model_class_item["padding_segement_value"]  # (sic) key spelling matches MODEL_CLASSES
    padding_att_value = model_class_item["padding_att_value"]
    left_pad = model_class_item.get('left_pad', False)

    batch_size_per_gpu_eval = 16

    eval_data_list = [{
        'uid': str(uuid.uuid4()),
        'premise': premise,
        'hypothesis': hypothesis,
        'label': 'h'    # placeholder "hidden" label; the gold label is unknown at inference time
    }]

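    # Batching schema: each Flint field controls how its key is collated into
    # a batch. ArrayIndexFlintField pads variable-length sequences to the
    # batch's longest sequence with the pad value chosen above (left- or
    # right-padded per left_pad).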
    batching_schema = {
        'uid': RawFlintField(),
        'y': LabelFlintField(),
        'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value, left_pad=left_pad),
        'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value, left_pad=left_pad),
        'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value, left_pad=left_pad),
    }

    data_transformer = NLITransform(model_name, tokenizer, max_length)

    d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(eval_data_list, data_transformer,
                                                                               batching_schema,
                                                                               batch_size_per_gpu_eval)
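    # d_dataset and d_sampler are unused below; only the dataloader is needed
    # to evaluate this single example.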

    if not cpu:
        # Single-GPU path: pin the process to device 0 and move the model there.
        torch.cuda.set_device(0)
        model.cuda(0)

    pred_output_list = eval_model(model, d_dataloader, args.global_rank, args)
    # Only one example was batched in, so return its prediction record.
    return pred_output_list[0]
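
A minimal usage sketch follows. The model class name, checkpoint path, and max length are illustrative placeholders (valid class names are whatever keys MODEL_CLASSES defines), and the exact shape of the returned record is whatever eval_model emits, typically including the example's uid and predicted label.

if __name__ == '__main__':
    pred = inference(
        model_class_name='roberta-large',                    # placeholder; must be a key in MODEL_CLASSES
        model_checkpoint_path='saved_models/nli/model.pt',   # placeholder checkpoint path
        max_length=156,                                      # placeholder sequence length
        premise='Two dogs chase a ball across the park.',
        hypothesis='Animals are playing outside.',
        cpu=True,
    )
    print(pred)    # prediction record for the single input example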