in scripts/adapet/ADAPET/src/adapet.py [0:0]
import os
import random

import torch
from torch import nn
from transformers import AlbertConfig, AlbertForMaskedLM, AutoModelForMaskedLM


class adapet(nn.Module):
    def __init__(self, config, tokenizer, dataset_reader):
        '''
        ADAPET model

        :param config: experiment configuration
        :param tokenizer: pretrained tokenizer matching the backbone
        :param dataset_reader: reader exposing the dataset's labels and PET patterns
        '''
        super(adapet, self).__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.dataset_reader = dataset_reader

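        # Prefer a local copy of the weights under pretrained_models/;
        # otherwise treat config.pretrained_weight as a model-hub identifier.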
        pretrained_file = os.path.join("pretrained_models", self.config.pretrained_weight)
        if not os.path.exists(pretrained_file):
            pretrained_file = self.config.pretrained_weight

if "albert" in pretrained_file:
albert_config = AlbertConfig.from_pretrained(pretrained_file)
self.model = AlbertForMaskedLM.from_pretrained(pretrained_file, config=albert_config)
else:
self.model = AutoModelForMaskedLM.from_pretrained(pretrained_file)
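        # (Either branch accepts any masked-LM checkpoint that transformers can
        # resolve, e.g. "albert-xxlarge-v2", the backbone used in the ADAPET paper.)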

        self.num_lbl = self.dataset_reader.get_num_lbl()

        # Mask idx lookup: a frozen embedding whose first max_text_length rows
        # are one-hot position vectors and whose last row is all zeros, so an
        # index of max_text_length (used as padding) selects nothing. This is
        # the trick that restricts the loss to the [MASK] positions.
        init_mask_idx_lkup = torch.cat([torch.eye(self.config.max_text_length),
                                        torch.zeros((1, self.config.max_text_length))], dim=0)
        self.mask_idx_lkup = nn.Embedding.from_pretrained(init_mask_idx_lkup)  # [max_text_length+1, max_text_length]
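        # Illustrative use of the lookup (the names below are assumptions, not
        # the original call sites): with mask positions padded to a fixed width
        # using the value max_text_length, the lookup returns one-hot rows (or
        # all-zero rows for padding) that select per-position values:
        #
        #     sel = self.mask_idx_lkup(mask_idx)  # [bs, num_masks, max_text_length]
        #     picked = torch.einsum('bml,bl->bm', sel, per_token_scores)
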
        # One-hot lookup over label ids
        self.lbl_idx_lkup = nn.Embedding.from_pretrained(torch.eye(self.num_lbl))  # [num_lbl, num_lbl]

        self.loss = nn.BCELoss(reduction="none")  # element-wise; reduced downstream

        # Set up patterns: either sample one per call or fix a single pattern
        self.pattern_list = self.dataset_reader.dataset_reader.pets
        if config.pattern_idx == "random":
            # Zero-arg callable that samples a pattern on every invocation
            self.pattern = lambda: random.choice(self.pattern_list)
        else:
            assert 0 < config.pattern_idx <= len(self.pattern_list), \
                "This dataset has {} patterns".format(len(self.pattern_list))
            # pattern_idx is 1-based
            self.pattern = self.pattern_list[config.pattern_idx - 1]
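
# Note the asymmetry above: self.pattern is a zero-arg lambda when
# pattern_idx == "random" but a plain pattern object otherwise, so call
# sites must resolve it before use. A minimal helper (hypothetical, not
# part of the original file) would keep that branching in one place:
#
#     def get_pattern(self):
#         return self.pattern() if callable(self.pattern) else self.pattern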