scripts/adapet/ADAPET/src/data/tokenize.py

import numpy as np
import math
import random

from collections import defaultdict


def tokenize_pet_mlm_txt(tokenizer, config, txt1, txt2, txt3, txt_trim, mask_idx=None):
    '''
    Tokenizes the text by trimming the appropriate txt

    :param tokenizer: HuggingFace tokenizer
    :param config: config exposing max_text_length and mask_alpha
    :param txt1: first text segment
    :param txt2: second text segment
    :param txt3: third text segment
    :param txt_trim: idx of the text segment to trim; will never contain the label
    :param mask_idx: optional precomputed array of positions to mask; sampled randomly if None
    :return trunc_input_ids: list of input ids (length exactly config.max_text_length)
    :return unsup_masked_ids: copy of trunc_input_ids with the mask token at each mask_idx position
    :return mask_idx: array of idx of mask tokens in trunc_input_ids
    '''
    txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
    txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
    txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]

    # Add 1 to account for the CLS token prepended below
    tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1

    # Don't need to trim text: pad out to max_text_length
    if tot_length <= config.max_text_length:
        trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        # Slice assignment with a list one token shorter shrinks the list by one;
        # prepending CLS below restores the length to max_text_length
        trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids

    # Trim text: drop tokens from the end of the segment selected by txt_trim
    else:
        num_trim = tot_length - config.max_text_length
        if txt_trim == 0:
            new_txt1_input_ids = txt1_input_ids[:-num_trim]
            trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
        elif txt_trim == 1:
            new_txt2_input_ids = txt2_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
        elif txt_trim == 2:
            new_txt3_input_ids = txt3_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt3_input_ids
        else:
            raise ValueError("Invalid Txt Trim")

    trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids

    # Sample positions to mask: a random fraction of up to mask_alpha * sample_length
    if mask_idx is None:
        sample_length = min(tot_length, config.max_text_length)
        upto_ratio_mask = np.random.rand()
        num_sample = max(int(upto_ratio_mask * config.mask_alpha * sample_length), 2) - 1
        mask_idx = random.sample(range(0, sample_length), k=num_sample)
        mask_idx = np.asarray(mask_idx)

    # np.copy converts the list to an array so fancy indexing can place mask tokens
    unsup_masked_ids = np.copy(trunc_input_ids)
    unsup_masked_ids[mask_idx] = tokenizer.mask_token_id

    return trunc_input_ids, unsup_masked_ids, mask_idx
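# --- Illustrative usage sketch (not part of the original ADAPET file) --------
# A minimal driver for tokenize_pet_mlm_txt, assuming a HuggingFace tokenizer
# and a config object exposing max_text_length and mask_alpha. The checkpoint
# name and config values below are hypothetical stand-ins, not ADAPET defaults.
def _demo_tokenize_pet_mlm_txt():
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = SimpleNamespace(max_text_length=64, mask_alpha=0.105)

    # txt2 is the segment that gets trimmed (txt_trim=1) if the total runs long
    input_ids, masked_ids, mask_idx = tokenize_pet_mlm_txt(
        tokenizer, config,
        txt1="The movie was surprisingly good.",
        txt2=" It kept me engaged the whole way through.",
        txt3=" Overall, a great watch.",
        txt_trim=1,
    )
    assert len(input_ids) == config.max_text_length
    assert all(masked_ids[i] == tokenizer.mask_token_id for i in mask_idx)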
def tokenize_pet_txt(tokenizer, config, txt1, txt2, txt3, mask_txt1, mask_txt2, mask_txt3, txt_trim):
    '''
    Tokenizes the text by trimming the appropriate txt

    :param tokenizer: HuggingFace tokenizer
    :param config: config exposing max_text_length
    :param txt1: first text segment
    :param txt2: second text segment
    :param txt3: third text segment
    :param mask_txt1: mask variant of txt1 (label replaced by the mask token)
    :param mask_txt2: mask variant of txt2
    :param mask_txt3: mask variant of txt3
    :param txt_trim: idx of the text segment to trim; will never contain the label
    :return trunc_input_ids: list of input ids (length exactly config.max_text_length)
    :return mask_idx: idx of the first mask token in the masked variant of trunc_input_ids
    '''
    txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
    txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
    txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]
    mask_txt1_input_ids = tokenizer(mask_txt1, add_special_tokens=False)["input_ids"]
    mask_txt2_input_ids = tokenizer(mask_txt2, add_special_tokens=False)["input_ids"]
    mask_txt3_input_ids = tokenizer(mask_txt3, add_special_tokens=False)["input_ids"]

    # Add 1 to account for the CLS token prepended below
    tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1
    tot_mask_length = len(mask_txt1_input_ids) + len(mask_txt2_input_ids) + len(mask_txt3_input_ids) + 1

    # Don't need to trim text: pad both variants out to max_text_length
    if tot_length <= config.max_text_length:
        trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids
        trunc_mask_input_ids = [tokenizer.pad_token_id] * config.max_text_length
        trunc_mask_input_ids[:tot_mask_length] = mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids

    # Trim text: drop tokens from the end of the segment selected by txt_trim,
    # in both the plain and the masked variant
    else:
        num_trim = tot_length - config.max_text_length
        if txt_trim == 0:
            new_txt1_input_ids = txt1_input_ids[:-num_trim]
            new_mask_txt1_input_ids = mask_txt1_input_ids[:-num_trim]
            trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = new_mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 1:
            new_txt2_input_ids = txt2_input_ids[:-num_trim]
            new_mask_txt2_input_ids = mask_txt2_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + new_mask_txt2_input_ids + mask_txt3_input_ids
        elif txt_trim == 2:
            new_txt3_input_ids = txt3_input_ids[:-num_trim]
            new_mask_txt3_input_ids = mask_txt3_input_ids[:-num_trim]
            trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt3_input_ids
            trunc_mask_input_ids = mask_txt1_input_ids + mask_txt2_input_ids + new_mask_txt3_input_ids
        else:
            raise ValueError("Invalid Txt Trim")

    trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids
    trunc_mask_input_ids = [tokenizer.cls_token_id] + trunc_mask_input_ids

    # Position of the mask token (the label slot) in the masked variant
    mask_idx = trunc_mask_input_ids.index(tokenizer.mask_token_id)

    return trunc_input_ids, mask_idx
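# --- Illustrative usage sketch (not part of the original ADAPET file) --------
# A minimal PET-style call to tokenize_pet_txt: the mask variant repeats the
# pattern with the verbalized label replaced by the tokenizer's mask token, so
# mask_idx lands on the label slot. Checkpoint and config values are
# hypothetical stand-ins.
def _demo_tokenize_pet_txt():
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = SimpleNamespace(max_text_length=64)

    txt1 = "The movie was surprisingly good."
    txt2 = " It was great."                           # label verbalized
    mask_txt2 = f" It was {tokenizer.mask_token}."    # same pattern, label masked
    txt3 = ""

    input_ids, mask_idx = tokenize_pet_txt(
        tokenizer, config,
        txt1, txt2, txt3,
        txt1, mask_txt2, txt3,
        txt_trim=0,
    )
    assert input_ids[0] == tokenizer.cls_token_id
    # mask_idx points at the label slot inside the CLS-prefixed sequence; the
    # unmasked variant holds the verbalized label there rather than a mask token
    assert input_ids[mask_idx] != tokenizer.mask_token_id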