def forward()

in opacus_lab/models/GPT2/model/masking.py [0:0]


    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """Build a boolean mask flagging padding positions in ``x``.

        Args:
            x: token tensor compared element-wise against ``self.pad_idx``.
                The last dimension is treated as the sequence axis —
                presumably shape ``(..., seq_len)``; TODO confirm with callers.
            offset: number of extra leading key columns to prepend. These
                are never flagged as padding (presumably cached past
                positions — verify against the caller).

        Returns:
            A ``torch.bool`` tensor of shape ``(..., seq_len, offset + seq_len)``
            on ``x.device``. ``True`` marks a key column whose token equals
            ``self.pad_idx``; the first ``offset`` columns are always ``False``.
            The result comes from ``expand``, so it is a broadcast view and
            every query row shares the same storage.
        """
        # Flag padding keys and add a singleton query axis for broadcasting.
        pad_keys = torch.unsqueeze(x == self.pad_idx, dim=-2)

        # All-False prefix covering the `offset` leading key columns.
        leading_dims = x.size()[:-1]
        prefix = torch.zeros(
            (*leading_dims, 1, offset),
            dtype=torch.bool,
            device=x.device,
        )

        key_mask = torch.cat([prefix, pad_keys], dim=-1)

        # Repeat the single row of key flags across every query position.
        target_shape = (*x.shape, key_mask.size(-1))
        return key_mask.expand(target_shape)