in torchaudio/models/emformer.py [0:0]
def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
    """Return the column-span widths for segment ``seg_idx``.

    The result lists the widths of consecutive column spans, grouped in
    triples of (before, active, after):

    * memory columns -- only present when ``self.use_mem`` is true;
    * right-context columns;
    * query-segment columns (extended to the left by up to
      ``self.left_context_length`` frames, clamped at 0).

    Args:
        seg_idx: Zero-based index of the segment being described.
        utterance_length: Total number of frames in the utterance.

    Returns:
        A list of 9 integers when ``self.use_mem`` is true, otherwise 6.

    NOTE(review): presumably consumed by the attention-mask builder, which
    is not visible from this block -- confirm against the caller.
    """
    segment_length = self.segment_length
    # The last segment may be shorter than ``segment_length``, hence ceil.
    total_segments = math.ceil(utterance_length / segment_length)
    right_context = self.right_context_length
    left_context = self.left_context_length

    # Each segment owns one right-context block of ``right_context`` columns.
    right_context_begin = seg_idx * right_context
    right_context_stop = right_context_begin + right_context
    right_context_total = right_context * total_segments

    # Query span: the segment itself plus left context, clamped to the
    # utterance boundaries.
    query_begin = max(seg_idx * segment_length - left_context, 0)
    query_stop = min((seg_idx + 1) * segment_length, utterance_length)

    context_and_query_widths = [
        right_context_begin,  # before right context
        right_context,  # right context
        right_context_total - right_context_stop,  # after right context
        query_begin,  # before query segment
        query_stop - query_begin,  # query segment
        utterance_length - query_stop,  # after query segment
    ]

    memory_widths: List[int] = []
    if self.use_mem:
        # At most ``max_memory_size`` memory slots preceding this segment.
        memory_begin = max(seg_idx - self.max_memory_size, 0)
        memory_total = total_segments - 1
        memory_widths = [
            memory_begin,  # before memory
            seg_idx - memory_begin,  # memory
            memory_total - seg_idx,  # after memory
        ]

    return memory_widths + context_and_query_widths