in low_rank_comparisons/src/model.py [0:0]
# module-level imports used by forward()
import torch
import torch.nn as nn

def forward(self, input_ids, lm_labels=None, lm_mask=None, past=None, len_past=None,
            label_smooth=0.0, is_meta_index=False, meta_cls_index=None,
            is_report_accuracy=False):
    _batch, _len = input_ids.shape
    # hidden_states: (batch, seq, dim); presents: per-layer states from the transformer
    hidden_states, presents = self.transformer(input_ids, past=past, len_past=len_past)
    if is_meta_index:
        # Gather, for every example, the layer outputs at the classification
        # position meta_cls_index, from the layers listed in self.meta_inputs.
        _b = torch.arange(0, input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        _inputs = []
        for _i in self.meta_inputs:
            # presents[_i]: (batch, seq, dim) -> pick one position per example
            _inputs.append(presents[_i][_b, meta_cls_index, :])
        _input = torch.cat(_inputs, dim=1)
        # Optionally project the concatenated features through the meta MLP.
        if self.meta_mlp is not None:
            _output = self.meta_mlp(_input)
        else:
            _output = _input
        return _output
    # lm_logits: (batch, seq, vocab)
    lm_logits = self.lm_head(hidden_states)

    if lm_labels is not None:
        if is_report_accuracy:
            _pred_token = torch.argmax(lm_logits, dim=-1)
            _hit = (_pred_token == lm_labels) * lm_mask

            _t1_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
            _all_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)

            for _b in range(0, _batch):
                # _t1_acc: 1.0 if the first masked position is predicted correctly.
                for _i in range(0, _len):
                    if lm_mask[_b, _i] >= 1.0:
                        if _hit[_b, _i] > 0:
                            _t1_acc[_b] = 1.0
                        break

                # _all_acc: 1.0 only if every masked position is predicted correctly.
                _is_succ = True
                for _i in range(0, _len):
                    if lm_mask[_b, _i] >= 1.0:
                        if _hit[_b, _i] <= 0:
                            _is_succ = False
                            break
                if _is_succ:
                    _all_acc[_b] = 1.0
        if label_smooth > 0.0001:
            # Label-smoothed NLL: (1 - eps) * NLL(target) + eps * uniform smoothing term.
            logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, lm_logits.size(-1)), dim=-1)
            nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1))
            nll_loss = nll_loss.squeeze(1)
            smooth_loss = -logprobs.mean(dim=-1)
            loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
            loss = loss.view(_batch, _len)
        else:
            # reduction='none' keeps the per-token loss so it can be masked below.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)).view(_batch, _len)

        if lm_mask is None:
            lm_mask = torch.ones(loss.shape, dtype=loss.dtype, device=loss.device)
        loss = loss * lm_mask
        # Average over masked tokens; the epsilon guards against an all-zero mask.
        loss = loss.sum() / (lm_mask.sum() + 0.0001)

        if is_report_accuracy:
            return lm_logits, loss, _t1_acc, _all_acc
        else:
            return lm_logits, loss
    return lm_logits, presents
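
# A standalone sketch of the masked, label-smoothed loss computed above, for
# reference. It is illustrative only: the shapes, smoothing coefficient, and
# random tensors below are assumptions, not values used by this repository.
if __name__ == "__main__":
    batch, seq_len, vocab = 2, 8, 100
    label_smooth = 0.1
    lm_logits = torch.randn(batch, seq_len, vocab)
    lm_labels = torch.randint(0, vocab, (batch, seq_len))
    lm_mask = (torch.rand(batch, seq_len) > 0.5).float()  # score ~half the positions

    logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, vocab), dim=-1)
    nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1)).squeeze(1)
    smooth_loss = -logprobs.mean(dim=-1)  # cross-entropy against a uniform distribution
    loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
    loss = loss.view(batch, seq_len) * lm_mask
    loss = loss.sum() / (lm_mask.sum() + 0.0001)  # mean over masked tokens
    print(float(loss))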