in timesformer/models/video_model_builder.py [0:0]
def _construct_network(self, cfg):
    """
    Builds a single pathway X3D model.

    The stem is stored as ``self.s1``, one residual stage per entry of
    ``self.block_basis`` is registered as ``s2``, ``s3``, ... and the
    classification head is stored as ``self.head``.

    Args:
        cfg (CfgNode): model building configs, details are in the
            comments of the config file.

    Raises:
        AssertionError: if ``cfg.MODEL.ARCH`` or ``cfg.RESNET.DEPTH`` is
            not a key of the corresponding lookup table.
        NotImplementedError: if ``self.enable_detection`` is set; no
            detection head is built by this method.
    """
    assert cfg.MODEL.ARCH in _POOL1
    assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH

    # Fail fast instead of silently returning a network without a head.
    if self.enable_detection:
        raise NotImplementedError(
            "Detection is not supported for X3D models."
        )

    num_groups = cfg.RESNET.NUM_GROUPS
    w_mul = cfg.X3D.WIDTH_FACTOR
    d_mul = cfg.X3D.DEPTH_FACTOR
    dim_res1 = self._round_width(self.dim_c1, w_mul)

    temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
    # temp_kernel[0][0] is the stem's temporal kernel list; it is extended
    # with a 3x3 spatial kernel and padded by half the temporal size.
    stem_temp_kernel = temp_kernel[0][0]

    self.s1 = stem_helper.VideoModelStem(
        dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
        dim_out=[dim_res1],
        kernel=[stem_temp_kernel + [3, 3]],
        stride=[[1, 2, 2]],
        padding=[[stem_temp_kernel[0] // 2, 1, 1]],
        norm_module=self.norm_module,
        stem_func_name="x3d_stem",
    )

    num_stages = len(self.block_basis)
    dim_in = dim_res1
    for stage, block in enumerate(self.block_basis):
        # Each basis entry is used as (repeats, width, stride).
        dim_out = self._round_width(block[1], w_mul)
        dim_inner = int(cfg.X3D.BOTTLENECK_FACTOR * dim_out)
        n_rep = self._round_repeats(block[0], d_mul)

        # Channelwise mode uses one group per inner channel.
        if cfg.X3D.CHANNELWISE_3x3x3:
            stage_groups = [dim_inner]
        else:
            stage_groups = [num_groups]

        # Drop-connect rate grows linearly with the stage index.
        drop_connect_rate = (
            cfg.MODEL.DROPCONNECT_RATE * (stage + 2) / (num_stages + 1)
        )

        res_stage = resnet_helper.ResStage(
            dim_in=[dim_in],
            dim_out=[dim_out],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=[block[2]],
            num_blocks=[n_rep],
            num_groups=stage_groups,
            num_block_temp_kernel=[n_rep],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            norm_module=self.norm_module,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[stage],
            drop_connect_rate=drop_connect_rate,
        )
        dim_in = dim_out
        # Stages are named starting at "s2" to follow the res2 convention.
        self.add_module(f"s{stage + 2}", res_stage)

    # NOTE(review): 32 is presumably the network's total spatial
    # downsampling factor -- confirm against the stem/stage strides.
    spat_sz = int(math.ceil(cfg.DATA.TRAIN_CROP_SIZE / 32.0))
    # The head is sized from the last stage's dim_out and dim_inner.
    self.head = head_helper.X3DHead(
        dim_in=dim_out,
        dim_inner=dim_inner,
        dim_out=cfg.X3D.DIM_C5,
        num_classes=cfg.MODEL.NUM_CLASSES,
        pool_size=[cfg.DATA.NUM_FRAMES, spat_sz, spat_sz],
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        act_func=cfg.MODEL.HEAD_ACT,
        bn_lin5_on=cfg.X3D.BN_LIN5,
    )