in src/model.py
def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
    if self.use_checkpoint and self.training:
        # Drop None-valued kwargs so only real arguments reach the module.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}

        def custom_forward(*inputs):
            output = self.module(*inputs, **kwargs)
            # torch.utils.checkpoint cannot pass None outputs through its
            # autograd function, so substitute an empty tensor; requires_grad
            # keeps the placeholder attached to the graph.
            empty = torch.tensor(
                [],
                dtype=torch.float,
                device=output[0].device,
                requires_grad=True)
            output = tuple(x if x is not None else empty for x in output)
            return output

        output = torch.utils.checkpoint.checkpoint(
            custom_forward,
            hidden_states,
            attention_mask,
            position_bias)
        # Restore None for the empty placeholders; only they have zero
        # elements. (x.size() returns a torch.Size, which never compares
        # equal to the int 0, so test numel() instead.)
        output = tuple(x if x.numel() != 0 else None for x in output)
    else:
        output = self.module(hidden_states, attention_mask, position_bias, **kwargs)
    return output
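
For context, a minimal self-contained sketch of how a block with this forward might be used. CheckpointedBlock and ToyLayer are hypothetical names invented for the example (neither appears in src/model.py); the checkpointing logic mirrors the method above.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class CheckpointedBlock(nn.Module):
    # Hypothetical wrapper, not from src/model.py: checkpoints an inner
    # module whose output tuple may contain None entries.
    def __init__(self, module, use_checkpoint=True):
        super().__init__()
        self.module = module
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
        if self.use_checkpoint and self.training:
            kwargs = {k: v for k, v in kwargs.items() if v is not None}

            def custom_forward(*inputs):
                output = self.module(*inputs, **kwargs)
                empty = torch.tensor([], dtype=torch.float,
                                     device=output[0].device,
                                     requires_grad=True)
                return tuple(x if x is not None else empty for x in output)

            output = torch.utils.checkpoint.checkpoint(
                custom_forward, hidden_states, attention_mask, position_bias)
            return tuple(x if x.numel() != 0 else None for x in output)
        return self.module(hidden_states, attention_mask, position_bias, **kwargs)


class ToyLayer(nn.Module):
    # Toy inner module: returns (hidden_states, None), e.g. when attention
    # probabilities are not requested.
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
        return (self.linear(hidden_states), None)


block = CheckpointedBlock(ToyLayer(16)).train()
x = torch.randn(2, 4, 16, requires_grad=True)
out, attn = block(x, attention_mask=None, position_bias=None)
out.sum().backward()  # ToyLayer's forward is recomputed during this backward
assert attn is None and x.grad is not None

The placeholder round-trip exists because torch.utils.checkpoint does not handle None entries in the checkpointed function's outputs gracefully; swapping them for empty tensors on the way in and back to None on the way out keeps the caller's (tensor, None) interface intact.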