in modeling/model.py [0:0]
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None,
            head_mask=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None,
            output_hidden_states=None, step=None, mention_labels=None, predict_mention=True, predict_lm=True,
            coref_attn=None, batch=None, coref_links=None):
    """Run the GPT-2 backbone, then the task heads (LM, binary classifier, mention/coref).

    Args:
        input_ids / past / attention_mask / token_type_ids / position_ids /
        head_mask / inputs_embeds / use_cache: forwarded unchanged to
            ``self.transformer`` (HuggingFace GPT-2 style signature).
        labels: when not None, the loss dict is computed and prepended to the outputs.
        output_attentions / output_hidden_states: accepted for signature
            compatibility but ignored — the transformer call below hard-codes
            both to True because the heads consume hidden states and attentions.
        step: decoding step index; None during training. Binary-classifier
            logits are produced only for step None or 0 (see comment below).
        mention_labels: targets for the mention-detection loss.
        predict_mention / predict_lm: toggles for the mention head and the
            LM (+ binary) heads respectively.
        coref_attn: precomputed coref attention context; if None and the
            coref-attn path is enabled, it is built from ``coref_links``.
        batch / coref_links: project-specific inputs passed to the loss and
            coref helpers — schema not visible here (defined elsewhere).

    Returns:
        Tuple: ``(loss_dict,)`` (only when ``labels`` is not None) followed by
        ``(bi_logits, lm_logits, cl_logits, attentions, *rest, coref_attn)``
        where ``rest`` comes from the transformer outputs (see slicing note below).
    """
    # run gpt2 — hidden states and attentions are forced on because the
    # coref-attn path and the reference loss need them.
    # last hidden state, (presents), (all hidden_states), (attentions)
    transformer_outputs = self.transformer(input_ids, past=past, attention_mask=attention_mask,
                                           token_type_ids=token_type_ids, position_ids=position_ids,
                                           head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache,
                                           output_attentions=True, output_hidden_states=True)
    hidden_states = transformer_outputs[0]      # (B, T, H) last-layer hidden states
    all_hidden_states = transformer_outputs[2]  # tuple of (B, T, H) with len = 1 + n_layer, 1 for embedding
    attentions = transformer_outputs[3]         # tuple of (B, n_heads, T, T), e.g., attentions[-1][b,n,i,:]
    # NOTE(review): the [2]/[3] indices assume `presents` occupies slot 1,
    # i.e. the use_cache layout of the tuple — confirm against the installed
    # transformers version.
    # get lm logits
    if predict_lm:
        if self.args.task == 'qr_coref' and self.args.use_coref_attn:
            # Build the coref attention context lazily if the caller did not
            # supply one (e.g. first decoding step), then mix it into the
            # hidden states before the LM head.
            if coref_attn is None:
                coref_attn = self.collect_coref_hiddens(coref_links, all_hidden_states, batch)
            hidden_states_lm = self.attn_on_coref(coref_attn, all_hidden_states, hidden_states)
            lm_logits = self.lm_head(hidden_states_lm)
        else:
            lm_logits = self.lm_head(hidden_states)
        # get binary logits
        if self.args.use_binary_cls and (step is None or step == 0):  # step=None for training, step=0 for first decoding step
            bi_logits = self.binary_cls2(self.binary_cls1(hidden_states))  # (B, T, 2)
        else:
            bi_logits = None
    else:
        lm_logits, bi_logits = None, None
    # get mention detection logits
    if predict_mention:
        cl_logits = self.cl_head(hidden_states)  # (B, T, C)
    else:
        cl_logits = None
    # prepare output
    # Drop the trailing hidden-states and attentions entries so the output
    # tuple shape matches the no-extras configuration; attentions are still
    # returned explicitly below.
    transformer_outputs = transformer_outputs[:-2]  # for output consistency, dont return H and A
    outputs = (bi_logits, lm_logits, cl_logits, attentions,) + transformer_outputs[1:]  # return all attentions; [1:] keeps presents (skips last hidden state)
    outputs = outputs + (coref_attn,)  # expose coref_attn so decoding loops can reuse it across steps
    # compute loss
    if labels is not None:
        # qr loss: binary loss and lm loss
        if 'qr' in self.args.task:
            loss_lm = self._compute_lm_loss(lm_logits, labels, batch)
            if self.args.use_binary_cls:
                loss_bi = self._compute_binary_loss(bi_logits, batch)
            else:
                # NOTE(review): torch.tensor(0) is integer dtype; fine for
                # summing with float losses but loss_total's dtype depends on
                # which branches fire — confirm this is intended.
                loss_bi = torch.tensor(0).to(self.args.device)
        else:
            loss_lm = torch.tensor(0).to(self.args.device)
            loss_bi = torch.tensor(0).to(self.args.device)
        # coref loss: mention loss and reference loss (reference loss is
        # supervised directly on the transformer attention maps)
        if 'coref' in self.args.task:
            loss_mention = self._compute_mention_loss(cl_logits, mention_labels)
            loss_reference = self._compute_reference_loss(batch, attentions)
        else:
            loss_mention, loss_reference = torch.tensor(0).to(self.args.device), torch.tensor(0).to(self.args.device)
        # final loss: unweighted sum of all four components
        loss_total = loss_bi + loss_lm + loss_mention + loss_reference
        loss_dict = {'bi': loss_bi, 'lm': loss_lm, 'mention': loss_mention, 'reference': loss_reference, 'total': loss_total}
        outputs = (loss_dict,) + outputs
    return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)