in slowfast/models/video_model_builder.py
def __init__(self, cfg):
    super().__init__()
    # Get parameters.
    assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
    self.cfg = cfg
    pool_first = cfg.MVIT.POOL_FIRST
    # Prepare input.
    spatial_size = cfg.DATA.TRAIN_CROP_SIZE
    temporal_size = cfg.DATA.NUM_FRAMES
    in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
    use_2d_patch = cfg.MVIT.PATCH_2D
    self.patch_stride = cfg.MVIT.PATCH_STRIDE
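    # With 2D patchification each frame is embedded independently, so a
    # temporal stride of 1 is prepended to the (H, W) patch stride.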
    if use_2d_patch:
        self.patch_stride = [1] + self.patch_stride
    # Prepare output.
    num_classes = cfg.MODEL.NUM_CLASSES
    embed_dim = cfg.MVIT.EMBED_DIM
    # Prepare backbone.
    num_heads = cfg.MVIT.NUM_HEADS
    mlp_ratio = cfg.MVIT.MLP_RATIO
    qkv_bias = cfg.MVIT.QKV_BIAS
    self.drop_rate = cfg.MVIT.DROPOUT_RATE
    depth = cfg.MVIT.DEPTH
    drop_path_rate = cfg.MVIT.DROPPATH_RATE
    mode = cfg.MVIT.MODE
    self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
    self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
    if cfg.MVIT.NORM == "layernorm":
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    else:
        raise NotImplementedError("Only supports layernorm.")
    self.num_classes = num_classes
    self.patch_embed = stem_helper.PatchEmbed(
        dim_in=in_chans,
        dim_out=embed_dim,
        kernel=cfg.MVIT.PATCH_KERNEL,
        stride=cfg.MVIT.PATCH_STRIDE,
        padding=cfg.MVIT.PATCH_PADDING,
        conv_2d=use_2d_patch,
    )
    self.input_dims = [temporal_size, spatial_size, spatial_size]
    assert self.input_dims[1] == self.input_dims[2]
    self.patch_dims = [
        self.input_dims[i] // self.patch_stride[i]
        for i in range(len(self.input_dims))
    ]
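    # patch_dims is the [T, H, W] shape of the token grid produced by the
    # patch embedding stem; its product is the number of patch tokens.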
    num_patches = math.prod(self.patch_dims)
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, depth)
    ]  # stochastic depth decay rule
    if self.cls_embed_on:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        pos_embed_dim = num_patches + 1
    else:
        pos_embed_dim = num_patches
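    # Separable positional embeddings keep one (H * W)-token spatial table and
    # one T-token temporal table instead of a single (T * H * W)-token table.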
    if self.sep_pos_embed:
        self.pos_embed_spatial = nn.Parameter(
            torch.zeros(
                1, self.patch_dims[1] * self.patch_dims[2], embed_dim
            )
        )
        self.pos_embed_temporal = nn.Parameter(
            torch.zeros(1, self.patch_dims[0], embed_dim)
        )
        if self.cls_embed_on:
            self.pos_embed_class = nn.Parameter(
                torch.zeros(1, 1, embed_dim)
            )
    else:
        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_embed_dim, embed_dim)
        )
    if self.drop_rate > 0.0:
        self.pos_drop = nn.Dropout(p=self.drop_rate)
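    # Per-block multipliers for channel width and head count. Entries default
    # to 1.0; cfg.MVIT.DIM_MUL / cfg.MVIT.HEAD_MUL list [block_index, mult]
    # pairs that override the multiplier at specific blocks.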
    dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
    for i in range(len(cfg.MVIT.DIM_MUL)):
        dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
    for i in range(len(cfg.MVIT.HEAD_MUL)):
        head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]
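    # Per-block pooling kernels and strides for q and k/v attention pooling;
    # an empty list means no pooling at that block. When POOL_KVQ_KERNEL is
    # unset, the kernel defaults to stride + 1 wherever the stride exceeds 1.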
    pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
    pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
    stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]
    for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
        stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
            1:
        ]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
            ]
    # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
    if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
        _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
        cfg.MVIT.POOL_KV_STRIDE = []
        for i in range(cfg.MVIT.DEPTH):
            if len(stride_q[i]) > 0:
                _stride_kv = [
                    max(_stride_kv[d] // stride_q[i][d], 1)
                    for d in range(len(_stride_kv))
                ]
            cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)
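    # The adaptive schedule divides the k/v stride by each block's q stride,
    # so the pooled k/v resolution stays roughly constant across the network.
    # Illustrative example (not values from this file): starting from
    # POOL_KV_STRIDE_ADAPTIVE = [1, 8, 8] with a q stride of [1, 2, 2] at
    # blocks 1, 3, and 14, the k/v stride becomes [1, 4, 4], [1, 2, 2], and
    # [1, 1, 1] at those blocks and is unchanged in between.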
    for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
        stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[i][
            1:
        ]
        if cfg.MVIT.POOL_KVQ_KERNEL is not None:
            pool_kv[
                cfg.MVIT.POOL_KV_STRIDE[i][0]
            ] = cfg.MVIT.POOL_KVQ_KERNEL
        else:
            pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                s + 1 if s > 1 else s
                for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
            ]
    self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None
    self.blocks = nn.ModuleList()
    if cfg.MODEL.ACT_CHECKPOINT:
        validate_checkpoint_wrapper_import(checkpoint_wrapper)
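    # Build the stack of MultiScaleBlocks. Channel width and head count grow
    # at the blocks listed in DIM_MUL / HEAD_MUL, and round_width keeps the
    # width divisible by the (updated) number of heads.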
    for i in range(depth):
        num_heads = round_width(num_heads, head_mul[i])
        embed_dim = round_width(embed_dim, dim_mul[i], divisor=num_heads)
        dim_out = round_width(
            embed_dim,
            dim_mul[i + 1],
            divisor=round_width(num_heads, head_mul[i + 1]),
        )
        attention_block = MultiScaleBlock(
            dim=embed_dim,
            dim_out=dim_out,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=self.drop_rate,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            kernel_q=pool_q[i] if len(pool_q) > i else [],
            kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
            stride_q=stride_q[i] if len(stride_q) > i else [],
            stride_kv=stride_kv[i] if len(stride_kv) > i else [],
            mode=mode,
            has_cls_embed=self.cls_embed_on,
            pool_first=pool_first,
        )
        if cfg.MODEL.ACT_CHECKPOINT:
            attention_block = checkpoint_wrapper(attention_block)
        self.blocks.append(attention_block)
        embed_dim = dim_out
    self.norm = norm_layer(embed_dim)
    self.head = head_helper.TransformerBasicHead(
        embed_dim,
        num_classes,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
    )
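    # Positional embeddings and the class token are drawn from a truncated
    # normal distribution (std 0.02); the remaining modules are initialized
    # by self._init_weights via self.apply below.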
    if self.sep_pos_embed:
        trunc_normal_(self.pos_embed_spatial, std=0.02)
        trunc_normal_(self.pos_embed_temporal, std=0.02)
        if self.cls_embed_on:
            trunc_normal_(self.pos_embed_class, std=0.02)
    else:
        trunc_normal_(self.pos_embed, std=0.02)
    if self.cls_embed_on:
        trunc_normal_(self.cls_token, std=0.02)
    self.apply(self._init_weights)
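
# Usage sketch (assumptions, not part of this file): this __init__ belongs to
# the MViT model in this module, cfg comes from the repo's config system, and
# inputs follow SlowFast's list-of-pathways convention with tensors shaped
# [batch, channels, frames, height, width].
#
#     from slowfast.config.defaults import get_cfg
#
#     cfg = get_cfg()
#     model = MViT(cfg)
#     clips = [
#         torch.rand(
#             2, 3, cfg.DATA.NUM_FRAMES,
#             cfg.DATA.TRAIN_CROP_SIZE, cfg.DATA.TRAIN_CROP_SIZE,
#         )
#     ]
#     preds = model(clips)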