in esm/model.py [0:0]
def max_tokens_per_msa_(self, value: int) -> None:
    """Set the batched-attention threshold on every MSA attention submodule.

    The MSA Transformer automatically batches attention computations when
    gradients are disabled, so larger MSAs fit in GPU memory at test time
    than would otherwise be possible. By default this kicks in when more
    than 2^14 tokens are passed in the input MSA. Set this value to
    infinity to disable the behavior.
    """
    attention_types = (RowSelfAttention, ColumnSelfAttention)
    for submodule in self.modules():
        if isinstance(submodule, attention_types):
            submodule.max_tokens_per_msa = value