# varlen_packing.py
# ---------------------------------------------------
# A universal patch for "flash_attention_2" to handle
# multi-subsequence "packed" sequences via integer-coded
# or block-coded 1D attention masks.
# ---------------------------------------------------
import torch
import torch.nn.functional as F
from transformers.utils import logging
logger = logging.get_logger(__name__)


def _get_seqlens_in_batch(mask_1d: torch.Tensor) -> torch.Tensor:
    """
    Convert a 1D integer-coded mask (like [1,1,1,2,2,2,2,3,3,3,0,0,...])
    into sub-sequence lengths. We assume sub-sequence IDs appear in ascending
    order and never revisit older IDs. Each contiguous run of a nonzero ID
    counts toward that sub-sequence's length.

    Example:
        mask_1d = [1,1,1,2,2,2,0,0]
        => sub-seq #1 => length=3, #2 => length=3
        => lengths = [3, 3]

    Returns a 1D int32 tensor of sub-sequence lengths, e.g. [3, 3].
    """
    mask_1d = mask_1d.view(-1)  # flatten
    # Filter out zeros (zero marks padding)
    nonzero_mask = mask_1d[mask_1d != 0]
    if nonzero_mask.numel() == 0:
        # no real tokens
        return torch.tensor([], dtype=torch.int32)

    # Walk the nonzero IDs and count each contiguous run.
    lengths = []
    count = 1
    last_id = nonzero_mask[0].item()
    for val in nonzero_mask[1:]:
        vid = val.item()
        if vid == last_id:
            count += 1
        else:
            lengths.append(count)
            last_id = vid
            count = 1
    # Flush the final run.
    if count > 0:
        lengths.append(count)
    return torch.tensor(lengths, dtype=torch.int32)
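

# ---------------------------------------------------
# Illustrative sanity check (not part of the original module). The helper
# name `_example_seqlens` is ours; it just re-runs the docstring-style
# example and asserts the expected run lengths.
# ---------------------------------------------------
def _example_seqlens():
    mask = torch.tensor([1, 1, 1, 2, 2, 2, 2, 3, 0, 0])
    lengths = _get_seqlens_in_batch(mask)
    # Three sub-sequences: ID 1 spans 3 tokens, ID 2 spans 4, ID 3 spans 1.
    assert lengths.tolist() == [3, 4, 1]
    assert lengths.dtype == torch.int32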


def get_unpad_data(attention_mask: torch.Tensor):
    """
    Our custom override for varlen "flash_attention_2".

    The stock `_get_unpad_data` returns:
        (indices, cu_seqlens, max_seqlen_in_batch)

    We interpret `attention_mask` as a 2D or 1D integer-coded array:
        shape => (batch_size, seq_len) or (seq_len,)
    For each row, we parse sub-seq lengths => build cu_seqlens => build indices.

    Example for a single row [1,1,1,2,2,2,2,0,0]:
        => sub-seq #1 => length=3, sub-seq #2 => length=4
        => cu_seqlens => [0,3,7], max_len => 4
        => indices => positions that are != 0 => [0,1,2,3,4,5,6]

    If there are multiple rows, we process them row by row, then concatenate.

    We also force the returned tensors onto the same device as
    `attention_mask`. (Essential to avoid "cu_seqlens_q must be on CUDA".)
    """
    dev = attention_mask.device  # force everything onto this device at the end

    if attention_mask.dim() == 1:
        # Single row
        mask_flat = attention_mask
        lengths = _get_seqlens_in_batch(mask_flat)
        if lengths.numel() == 0:
            # no real tokens
            indices = torch.tensor([], dtype=torch.long, device=dev)
            cu_seqlens = torch.tensor([0], dtype=torch.int32, device=dev)
            return (indices, cu_seqlens, 0)
        # flash_attn expects int32 cumulative sequence lengths.
        cu_seqlens = torch.cat([
            torch.tensor([0], dtype=torch.int32, device=dev),
            torch.cumsum(lengths, dim=0, dtype=torch.int32).to(dev)
        ], dim=0)
        max_len = lengths.max().item()
        indices = (mask_flat != 0).nonzero().squeeze(-1).to(dev)
        return (indices, cu_seqlens, max_len)

    elif attention_mask.dim() == 2:
        bsz, seqlen = attention_mask.shape
        indices_list = []
        cu_seqlens_list = [0]
        current_offset = 0
        max_len = 0
        for row_idx in range(bsz):
            row = attention_mask[row_idx]
            lengths = _get_seqlens_in_batch(row)
            if lengths.numel() > 0:
                # Continue the cumulative boundaries from the previous row.
                new_cu = torch.cumsum(lengths, dim=0) + cu_seqlens_list[-1]
                cu_seqlens_list.extend(new_cu.tolist())
                row_max = lengths.max().item()
                if row_max > max_len:
                    max_len = row_max
            # (rows with no real tokens contribute nothing to cu_seqlens)
            row_indices = (row != 0).nonzero().squeeze(-1) + current_offset
            indices_list.append(row_indices)
            current_offset += seqlen

        if len(cu_seqlens_list) == 1:
            # no real tokens at all
            indices = torch.tensor([], dtype=torch.long, device=dev)
            cu_seqlens = torch.tensor([0], dtype=torch.int32, device=dev)
            return (indices, cu_seqlens, 0)

        # Build final tensors
        indices = torch.cat(indices_list, dim=0).to(dev)
        cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=dev)
        return (indices, cu_seqlens, max_len)

    else:
        raise ValueError(
            f"get_unpad_data expects a 1D or 2D mask, got shape {attention_mask.shape}"
        )
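

# ---------------------------------------------------
# Illustrative sanity check (not part of the original module). The helper
# name `_example_unpad` is ours; it packs two rows and verifies the
# (indices, cu_seqlens, max_len) triple in the layout flash_attn varlen uses.
# ---------------------------------------------------
def _example_unpad():
    # Row 0 packs two sub-sequences (lengths 3 and 2), row 1 packs one (length 2).
    mask = torch.tensor([
        [1, 1, 1, 2, 2, 0],
        [1, 1, 0, 0, 0, 0],
    ])
    indices, cu_seqlens, max_len = get_unpad_data(mask)
    # Indices are flattened positions of non-padding tokens across the batch.
    assert indices.tolist() == [0, 1, 2, 3, 4, 6, 7]
    # Cumulative boundaries of the three sub-sequences (3, 2, and 2 tokens).
    assert cu_seqlens.tolist() == [0, 3, 5, 7]
    assert cu_seqlens.dtype == torch.int32
    assert max_len == 3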


def apply_varlen_patch():
    """
    Monkey-patch HF's `_get_unpad_data` with `get_unpad_data` above.
    This modifies the varlen logic used by "flash_attention_2".
    """
    try:
        from transformers import modeling_flash_attention_utils
    except ImportError:
        logger.warning(
            "apply_varlen_patch: transformers>=4.45 needed for flash_attention_2. Not patching."
        )
        return None

    if not hasattr(modeling_flash_attention_utils, "_get_unpad_data"):
        logger.warning(
            "apply_varlen_patch: can't find `_get_unpad_data` in modeling_flash_attention_utils. "
            "Your Transformers version might not have flash_attn varlen logic."
        )
        return None

    # Replace, keeping a handle to the original implementation
    old_func = modeling_flash_attention_utils._get_unpad_data
    modeling_flash_attention_utils._get_unpad_data = get_unpad_data
    logger.info(
        "apply_varlen_patch: Replaced `_get_unpad_data` with our varlen integer-coded approach."
    )
    return old_func  # keep this if you want to restore the original later
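

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; assumes torch is installed and,
    # for the patch itself, a transformers build that exposes
    # `modeling_flash_attention_utils._get_unpad_data`).
    _example_seqlens()
    _example_unpad()
    original = apply_varlen_patch()
    if original is not None:
        # For a dry run, put the stock implementation back afterwards.
        from transformers import modeling_flash_attention_utils
        modeling_flash_attention_utils._get_unpad_data = original
    print("varlen_packing self-checks passed.")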