in scripts/adapet/ADAPET/src/data/tokenize.py
def tokenize_pet_txt(tokenizer, config, txt1, txt2, txt3, mask_txt1, mask_txt2, mask_txt3, txt_trim):
    '''
    Tokenizes the three text segments and trims the segment selected by
    txt_trim so the result fits in config.max_text_length tokens.

    :param tokenizer: HuggingFace-style tokenizer
    :param config: config object exposing max_text_length
    :param txt1: first text segment
    :param txt2: second text segment
    :param txt3: third text segment
    :param mask_txt1: copy of txt1 with the label (if present) replaced by the mask token
    :param mask_txt2: copy of txt2 with the label (if present) replaced by the mask token
    :param mask_txt3: copy of txt3 with the label (if present) replaced by the mask token
    :param txt_trim: index (0, 1, or 2) of the segment to trim; the trimmed segment never contains the label
    :return trunc_input_ids: list of input ids, exactly config.max_text_length long, starting with the CLS id
    :return mask_idx: index of the mask token in the masked copy of the input (the label tokens start at the same position in trunc_input_ids, in case the label is more than 1 token)
    '''
    txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
    txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
    txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]
    mask_txt1_input_ids = tokenizer(mask_txt1, add_special_tokens=False)["input_ids"]
    mask_txt2_input_ids = tokenizer(mask_txt2, add_special_tokens=False)["input_ids"]
    mask_txt3_input_ids = tokenizer(mask_txt3, add_special_tokens=False)["input_ids"]

    # Add 1 to account for CLS rep
    tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1
    tot_mask_length = len(mask_txt1_input_ids) + len(mask_txt2_input_ids) + len(mask_txt3_input_ids) + 1
    # Don't need to trim text: pad out to max_text_length
    if tot_length <= config.max_text_length:
        trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        # tot_length counts the CLS token, so this slice assignment writes
        # tot_length - 1 ids into the first tot_length slots, shrinking the
        # list by one and leaving room for the CLS id prepended below
        trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids
        trunc_mask_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        trunc_mask_input_ids[:tot_mask_length] = mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
    # Trim num_trim ids off the right of the segment selected by txt_trim
    else:
        num_trim = tot_length - config.max_text_length
        if txt_trim == 0:
            new_txt1_input_ids = txt1_input_ids[:-num_trim]
            new_mask_txt1_input_ids = mask_txt1_input_ids[:-num_trim]
            trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = new_mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 1:
            new_txt2_input_ids = txt2_input_ids[:-num_trim]
            new_mask_txt2_input_ids = mask_txt2_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + new_mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 2:
            new_txt3_input_ids = txt3_input_ids[:-num_trim]
            new_mask_txt3_input_ids = mask_txt3_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + mask_txt2_input_ids + new_mask_txt3_input_ids
        else:
            raise ValueError("txt_trim must be 0, 1, or 2, got %r" % txt_trim)
    # Prepend the CLS token so trunc_input_ids is exactly max_text_length long
    trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids
    trunc_mask_input_ids = [tokenizer.cls_token_id] + trunc_mask_input_ids

    # The trimmed segment never contains the label, so the mask token always
    # survives truncation and index() cannot raise here
    mask_idx = trunc_mask_input_ids.index(tokenizer.mask_token_id)

    return trunc_input_ids, mask_idx
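
Below is a minimal usage sketch, not part of the original file: it assumes a HuggingFace tokenizer whose vocabulary defines CLS, PAD, and MASK tokens, and a config object exposing a max_text_length attribute. The model name, pattern strings, and label word are illustrative assumptions, not taken from the ADAPET repo.

# --- Usage sketch (illustrative, not part of the original file) ---
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    # Assumed names: any tokenizer with CLS/PAD/MASK tokens works the same way
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = SimpleNamespace(max_text_length=32)

    # txt3 carries the verbalized label ("great") in the plain copy and the
    # mask token in the masked copy; txt_trim=0 selects txt1 for trimming,
    # since txt1 never contains the label
    input_ids, mask_idx = tokenize_pet_txt(
        tokenizer, config,
        "the movie was fun to watch .", " overall it was", " great .",
        "the movie was fun to watch .", " overall it was",
        " %s ." % tokenizer.mask_token,
        txt_trim=0,
    )
    assert len(input_ids) == config.max_text_length
    print(mask_idx, tokenizer.decode(input_ids))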