def forward()

in modeling/model.py

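Forward pass of the GPT-2-based model. It runs the transformer backbone, then computes LM logits (optionally attending over coreference-linked hidden states when use_coref_attn is set), token-level binary classification logits, and mention-detection logits. When labels is given, it additionally returns a loss dictionary with the binary, LM, mention, and reference losses and their sum.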

	def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None,
				head_mask=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None,
				output_hidden_states=None, step=None, mention_labels=None, predict_mention=True, predict_lm=True,
				coref_attn=None, batch=None, coref_links=None):

		# run gpt2
		# last hidden state, (presents), (all hidden_states), (attentions)
		transformer_outputs = self.transformer(input_ids, past=past, attention_mask=attention_mask,
												token_type_ids=token_type_ids, position_ids=position_ids,
												head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache,
												output_attentions=True, output_hidden_states=True)

		hidden_states = transformer_outputs[0] # (B, T, H)
		all_hidden_states = transformer_outputs[2] # tuple of (B, T, H), length n_layer + 1 (first entry is the embedding output)
		attentions = transformer_outputs[3] # tuple of (B, n_heads, T, T) per layer, e.g., attentions[-1][b, n, i, :] is the last layer's attention of head n at query position i

		# get lm logits
		if predict_lm:
			if self.args.task == 'qr_coref' and self.args.use_coref_attn:
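				# coref-attention path: build coref_attn from the coreference links (or reuse a cached one), then let the hidden states attend over it before the LM head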
				if coref_attn is None:
					coref_attn = self.collect_coref_hiddens(coref_links, all_hidden_states, batch)

				hidden_states_lm = self.attn_on_coref(coref_attn, all_hidden_states, hidden_states)
				lm_logits = self.lm_head(hidden_states_lm)

			else:
				lm_logits = self.lm_head(hidden_states)

			# get binary logits
			if self.args.use_binary_cls and (step is None or step == 0): # step=None for training, step=0 for first decoding step
				bi_logits = self.binary_cls2(self.binary_cls1(hidden_states)) # (B, T, 2)
			else:
				bi_logits = None
		else:
			lm_logits, bi_logits = None, None

		# get mention detection logits
		if predict_mention:
			cl_logits = self.cl_head(hidden_states) # (B, T, C)
		else:
			cl_logits = None

		# prepare output
		transformer_outputs = transformer_outputs[:-2] # for output consistency, don't return all_hidden_states and attentions here (keep only last_hidden_state and presents)
		outputs = (bi_logits, lm_logits, cl_logits, attentions,) + transformer_outputs[1:] # return all attentions
		outputs = outputs + (coref_attn,)

		# compute loss
		if labels is not None:
			# qr loss: binary loss and lm loss
			if 'qr' in self.args.task:
				loss_lm = self._compute_lm_loss(lm_logits, labels, batch)

				if self.args.use_binary_cls:
					loss_bi = self._compute_binary_loss(bi_logits, batch)
				else:
					loss_bi = torch.tensor(0).to(self.args.device)
			else:
				loss_lm = torch.tensor(0).to(self.args.device)
				loss_bi = torch.tensor(0).to(self.args.device)

			# coref loss: mention loss and reference loss
			if 'coref' in self.args.task:
				loss_mention = self._compute_mention_loss(cl_logits, mention_labels)
				loss_reference = self._compute_reference_loss(batch, attentions)
			else:
				loss_mention = torch.tensor(0).to(self.args.device)
				loss_reference = torch.tensor(0).to(self.args.device)

			# final loss
			loss_total = loss_bi + loss_lm + loss_mention + loss_reference
			loss_dict = {'bi': loss_bi, 'lm': loss_lm, 'mention': loss_mention, 'reference': loss_reference, 'total': loss_total}
			outputs = (loss_dict,) + outputs

		return outputs  # (loss_dict), bi_logits, lm_logits, cl_logits, attentions, (presents), coref_attn
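
A minimal sketch of a training step built around this forward(), assuming an instantiated model of this class with args.task = 'qr_coref', already-prepared input_ids, attention_mask, labels, and mention_labels tensors, the task-specific batch object expected by the loss helpers, and a standard optimizer. The names model, optimizer, and the tensor variables are assumptions for illustration, not part of the file above.

# hypothetical training step; model, optimizer, and the input tensors are assumed to exist
outputs = model(
	input_ids=input_ids,
	attention_mask=attention_mask,
	labels=labels,                   # labels != None triggers the loss computation above
	mention_labels=mention_labels,   # consumed by the mention loss when 'coref' is in args.task
	batch=batch,                     # passed through to the binary/LM/reference loss helpers
	use_cache=False,
)

loss_dict = outputs[0]               # {'bi', 'lm', 'mention', 'reference', 'total'}
loss_dict['total'].backward()
optimizer.step()
optimizer.zero_grad()

At decoding time the same forward() is called step by step: per the comments above, step=0 marks the first decoding step, and the returned coref_attn (the last element of outputs) can be passed back in on later steps to avoid recomputing it.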