in classy_vision/models/resnext3d.py [0:0]
def _parse_config(config):
ret_config = {}
required_args = [
"input_planes",
"clip_crop_size",
"skip_transformation_type",
"residual_transformation_type",
"frames_per_clip",
"num_blocks",
]
for arg in required_args:
assert arg in config, "resnext3d model requires argument %s" % arg
ret_config[arg] = config[arg]
# Default setting for model stem, which is considered as stage 0. Stage
# index starts from 0 as implemented in ResStageBase._block_name() method.
# stem_planes: No. of output channles of conv op in stem
# stem_temporal_kernel: temporal size of conv op in stem
# stem_spatial_kernel: spatial size of conv op in stem
# stem_maxpool: by default, spatial maxpool op is disabled in stem
ret_config.update(
{
"input_key": config.get("input_key", None),
"stem_name": config.get("stem_name", "resnext3d_stem"),
"stem_planes": config.get("stem_planes", 64),
"stem_temporal_kernel": config.get("stem_temporal_kernel", 3),
"stem_spatial_kernel": config.get("stem_spatial_kernel", 7),
"stem_maxpool": config.get("stem_maxpool", False),
}
)
# Default setting for model stages 1, 2, 3 and 4
# stage_planes: No. of output channel of 1st conv op in stage 1
# stage_temporal_kernel_basis: Basis of temporal kernel sizes for each of
# the stage.
# temporal_conv_1x1: if True, do temporal convolution in the fist
# 1x1 Conv3d. Otherwise, do it in the second 3x3 Conv3d (default settting)
# stage_temporal_stride: temporal stride for each stage
# stage_spatial_stride: spatial stride for each stage
# num_groups: No. of groups in 2nd (group) conv in the residual transformation
# width_per_group: No. of channels per group in 2nd (group) conv in the
# residual transformation
ret_config.update(
{
"stage_planes": config.get("stage_planes", 256),
"stage_temporal_kernel_basis": config.get(
"stage_temporal_kernel_basis", [[3], [3], [3], [3]]
),
"temporal_conv_1x1": config.get(
"temporal_conv_1x1", [False, False, False, False]
),
"stage_temporal_stride": config.get(
"stage_temporal_stride", [1, 2, 2, 2]
),
"stage_spatial_stride": config.get(
"stage_spatial_stride", [1, 2, 2, 2]
),
"num_groups": config.get("num_groups", 1),
"width_per_group": config.get("width_per_group", 64),
}
)
# Default setting for model parameter initialization
ret_config.update(
{
"zero_init_residual_transform": config.get(
"zero_init_residual_transform", False
)
}
)
assert is_pos_int_list(ret_config["num_blocks"])
assert is_pos_int(ret_config["stem_planes"])
assert is_pos_int(ret_config["stem_temporal_kernel"])
assert is_pos_int(ret_config["stem_spatial_kernel"])
assert type(ret_config["stem_maxpool"]) == bool
assert is_pos_int(ret_config["stage_planes"])
assert isinstance(ret_config["stage_temporal_kernel_basis"], Sequence)
assert all(
is_pos_int_list(l) for l in ret_config["stage_temporal_kernel_basis"]
)
assert isinstance(ret_config["temporal_conv_1x1"], Sequence)
assert is_pos_int_list(ret_config["stage_temporal_stride"])
assert is_pos_int_list(ret_config["stage_spatial_stride"])
assert is_pos_int(ret_config["num_groups"])
assert is_pos_int(ret_config["width_per_group"])
return ret_config