def forward()

in lib/masked_attention.py [0:0]


    def forward(self, input_bte, first_bt, state):
        """Run this layer over one chunk of timesteps.

        Args:
            input_bte: activations, indexed (batch, time, embedding);
                only its ``.device`` is read directly here.
            first_bt: per-(batch, time) episode-start indicator; its second
                dimension gives the number of timesteps in this chunk.
            state: pair ``(state_mask, xf_state)`` — the recurrent mask state
                and the transformer block's recurrent state.

        Returns:
            ``(output, (state_mask, xf_state))`` with the states updated for
            the next chunk.
        """
        mask_state, transformer_state = state
        timesteps = first_bt.shape[1]

        if self.mask == "clipped_causal":
            # Rebuild the attention mask for this chunk so that attention
            # never crosses an episode boundary, and refresh the mask state.
            attn_mask, mask_state = get_mask(
                first_b11=first_bt[:, [[0]]],
                state_mask=mask_state,
                t=timesteps,
                T=timesteps + self.maxlen,
                maxlen=self.maxlen,
                heads=self.heads,
                device=input_bte.device,
            )
            self.orc_block.attn.mask = attn_mask

        output, transformer_state = self.orc_block(input_bte, transformer_state)
        return output, (mask_state, transformer_state)