def tokenize_pet_txt()

in scripts/adapet/ADAPET/src/data/tokenize.py


def tokenize_pet_txt(tokenizer, config, txt1, txt2, txt3, mask_txt1, mask_txt2, mask_txt3, txt_trim):
    '''
    Tokenizes the pattern text, trimming the segment selected by txt_trim when
    the full sequence would exceed config.max_text_length

    :param txt1: first segment of the pattern text
    :param txt2: second segment of the pattern text
    :param txt3: third segment of the pattern text
    :param mask_txt1: first segment of the masked variant, where the verbalized label is replaced by [MASK] token(s)
    :param mask_txt2: second segment of the masked variant
    :param mask_txt3: third segment of the masked variant
    :param txt_trim: index (0, 1, or 2) of the segment to trim; the trimmed segment never contains the label
    :return trunc_input_ids: list of input ids, exactly config.max_text_length long
    :return mask_idx: index of the first [MASK] token, i.e. where the (possibly multi-token) label span begins in trunc_input_ids
    '''
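    # Tokenize each pattern segment separately, without special tokens;
    # [CLS] is prepended manually at the end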
    txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
    txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
    txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]

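    # Same segments for the masked variant, with [MASK] standing in for the label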
    mask_txt1_input_ids = tokenizer(mask_txt1, add_special_tokens=False)["input_ids"]
    mask_txt2_input_ids = tokenizer(mask_txt2, add_special_tokens=False)["input_ids"]
    mask_txt3_input_ids = tokenizer(mask_txt3, add_special_tokens=False)["input_ids"]

    # Add 1 to reserve a slot for the [CLS] id prepended at the end
    tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1
    tot_mask_length = len(mask_txt1_input_ids) + len(mask_txt2_input_ids) + len(mask_txt3_input_ids) + 1
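    # e.g. segments of 100/50/80 tokens give tot_length = 231; with
    # config.max_text_length = 200, num_trim = 31 ids are cut from one segment below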

    # Don't need to trim text
    if tot_length <= config.max_text_length:
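        # Pad to max_text_length, then overwrite the first tot_length slots with the
        # tot_length - 1 real ids; the slice assignment shrinks the list by one,
        # leaving room for the [CLS] id prepended below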
        trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids

        trunc_mask_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        trunc_mask_input_ids[:tot_mask_length] = mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids

    # Trim text
    else:
        num_trim = tot_length - config.max_text_length
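        # Trim only the segment selected by txt_trim; per the docstring it never
        # contains the label, and it is assumed to be longer than num_trim tokens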

        if txt_trim == 0:
            new_txt1_input_ids = txt1_input_ids[:-num_trim]
            new_mask_txt1_input_ids = mask_txt1_input_ids[:-num_trim]
            trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = new_mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 1:
            new_txt2_input_ids = txt2_input_ids[:-num_trim]
            new_mask_txt2_input_ids = mask_txt2_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + new_mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 2:
            new_txt3_input_ids = txt3_input_ids[:-num_trim]
            new_mask_txt3_input_ids = mask_txt3_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + mask_txt2_input_ids + new_mask_txt3_input_ids
        else:
            raise ValueError(f"Invalid txt_trim value: {txt_trim}")

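    # Prepend the [CLS] id that tot_length / tot_mask_length reserved a slot for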
    trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids
    trunc_mask_input_ids = [tokenizer.cls_token_id] + trunc_mask_input_ids

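    # Index of the first [MASK] id; with a multi-token label this is where the span starts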
    mask_idx = trunc_mask_input_ids.index(tokenizer.mask_token_id)

    return trunc_input_ids, mask_idx
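
For reference, here is a minimal sketch of how this helper might be called with a Hugging Face tokenizer. The tokenizer checkpoint, config stub, pattern text, and import path are illustrative assumptions, not taken from the repository; any masked-LM tokenizer that defines cls, mask, and pad tokens should behave the same way.

from types import SimpleNamespace
from transformers import AutoTokenizer
from src.data.tokenize import tokenize_pet_txt  # assumes the ADAPET root is on PYTHONPATH

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
config = SimpleNamespace(max_text_length=256)                   # hypothetical config stub

premise = "The cat sat on the mat."
hypothesis = "A cat is sitting."
label = "yes"  # verbalized label; the masked variant swaps it for [MASK]

input_ids, mask_idx = tokenize_pet_txt(
    tokenizer, config,
    premise, " question: " + hypothesis + " answer: ", label,
    premise, " question: " + hypothesis + " answer: ", tokenizer.mask_token,
    0,  # trim txt1 (the premise) on overflow; it never contains the label
)

assert len(input_ids) == config.max_text_length
assert input_ids[0] == tokenizer.cls_token_id
print(mask_idx)  # position where the label token sits in input_ids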