in opacus_lab/models/GPT2/model/masking.py [0:0]
def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Build a boolean mask flagging padding positions of ``x``.

    Args:
        x: tensor of token ids (presumably an integer dtype — confirm with
            caller). Its last dimension becomes the key axis of the mask and
            is also used as the length of the broadcast query axis.
        offset: number of extra key columns prepended on the left. These
            columns are never flagged (presumably past/cached positions —
            confirm with caller).

    Returns:
        Bool tensor of shape ``x.shape + (offset + x.shape[-1],)``. A value
        of ``True`` marks a key position whose token equals
        ``self.pad_idx``. The result is produced by ``expand``, so every
        query row is a view of the same underlying storage.
    """
    lead_dims = x.size()[:-1]

    # Key-side pad flags with a singleton query axis, so they can broadcast.
    pad_keys = x.eq(self.pad_idx).unsqueeze(-2)

    # The leading `offset` key columns are always unmasked (all False).
    prefix = torch.zeros(
        lead_dims + (1, offset), dtype=torch.bool, device=x.device
    )

    key_mask = torch.cat([prefix, pad_keys], dim=-1)
    target_shape = x.shape + key_mask.shape[-1:]
    return key_mask.expand(target_shape)