# lib/masked_attention.py
def forward(self, input_bte, first_bt, state):
    """Forward propagation of a single layer.

    input_bte: activations shaped (batch, time, embedding).
    first_bt: (batch, time) flags marking episode starts.
    state: tuple of (state_mask, xf_state) carried between calls.
    """
    state_mask, xf_state = state
    t = first_bt.shape[1]
    if self.mask == "clipped_causal":
        # Rebuild the causal mask, clipped so queries cannot attend across
        # an episode boundary or beyond `maxlen` cached timesteps.
        new_mask, state_mask = get_mask(
            first_b11=first_bt[:, [[0]]],  # episode-start flag, shaped (batch, 1, 1)
            state_mask=state_mask,
            t=t,  # number of new (query) timesteps in this call
            T=t + self.maxlen,  # key length: new timesteps plus cached context
            maxlen=self.maxlen,
            heads=self.heads,
            device=input_bte.device,
        )
        self.orc_block.attn.mask = new_mask
    # Run the underlying transformer block with its recurrent state.
    output, xf_state = self.orc_block(input_bte, xf_state)
    return output, (state_mask, xf_state)
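
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal sketch of driving this recurrent forward() over consecutive
# chunks. The owning module instance (`layer` here), its construction, and
# the `initial_state` helper are assumptions, not shown in this excerpt;
# shapes follow the (batch, time, embedding) naming used above.
import torch

B, T, E = 2, 8, 512                             # batch, timesteps per chunk, embedding width
input_bte = torch.randn(B, T, E)                # activations for T new timesteps
first_bt = torch.zeros(B, T, dtype=torch.bool)
first_bt[:, 0] = True                           # mark an episode start at t=0

state = layer.initial_state(B)                  # assumed helper returning (state_mask, xf_state)
output, state = layer(input_bte, first_bt, state)

# Feeding the returned state back in lets attention span chunk boundaries
# (up to `maxlen` cached timesteps) without recomputing earlier inputs.
next_bte = torch.randn(B, T, E)
no_firsts = torch.zeros(B, T, dtype=torch.bool)
output, state = layer(next_bte, no_firsts, state)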