in pt/vmz/models/r2plus1d.py [0:0]
def r2plus1d_152(pretraining="", use_pool1=True, progress=False, **kwargs):
    """Build an R(2+1)D-152 network, optionally loading pretrained weights.

    The network is a ``Bottleneck``-based ResNet with ``Conv2Plus1D`` conv
    makers in all four stages and stage depths ``[3, 8, 36, 3]``.

    Args:
        pretraining: Name of the pretrained checkpoint to load. Recognized
            values are ``"ig65m_32frms"``, ``"ig_ft_kinetics_32frms"``,
            ``"sports1m_32frms"`` and ``"sports1m_ft_kinetics_32frms"``.
            Any other value gives a randomly initialized network. A non-empty
            unrecognized value also emits a ``UserWarning``. The default ``""``
            means "no pretraining" and does not warn.
        use_pool1: If True, use ``R2Plus1dStem_Pool`` as the stem; otherwise
            use ``R2Plus1dStem``.
        progress: Forwarded to ``_generic_resnet`` and to
            ``torch.hub.load_state_dict_from_url`` (download progress bar).
        **kwargs: Extra keyword arguments forwarded to ``_generic_resnet``.

    Returns:
        The model built by ``_generic_resnet``, with every ``nn.BatchNorm3d``
        set to ``eps=1e-3`` and ``momentum=0.9``. When ``pretraining`` is
        recognized, weights from ``model_urls["r2plus1d_" + pretraining]``
        are also loaded.
    """
    avail_pretrainings = [
        "ig65m_32frms",
        "ig_ft_kinetics_32frms",
        "sports1m_32frms",
        "sports1m_ft_kinetics_32frms",
    ]
    pretrained = pretraining in avail_pretrainings
    if pretrained:
        arch = "r2plus1d_" + pretraining
    else:
        if pretraining:
            # Only complain about a name the caller actually supplied. The
            # default "" simply means "train from scratch".
            # The whole message is an f-string, so the list is interpolated
            # rather than printed as the literal text "{avail_pretrainings}".
            warnings.warn(
                "Unrecognized pretraining dataset, continuing with randomly "
                "initialized network."
                f" Available pretrainings: {avail_pretrainings}",
                UserWarning,
            )
        # Name the architecture after this factory. The old value,
        # "r2plus1d_34", was a copy-paste from the 34-layer variant.
        # NOTE(review): how _generic_resnet uses `arch` is not visible
        # here; confirm nothing keys on the old name.
        arch = "r2plus1d_152"

    model = _generic_resnet(
        arch,
        pretrained,
        progress,
        block=Bottleneck,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[3, 8, 36, 3],
        stem=R2Plus1dStem_Pool if use_pool1 else R2Plus1dStem,
        **kwargs,
    )

    # Override BatchNorm settings to mirror the Caffe2 reference models.
    # NOTE(review): PyTorch and Caffe2 define momentum oppositely.
    #   PyTorch: running = (1 - momentum) * running + momentum * batch.
    #   Caffe2 SpatialBN: running = momentum * running + (1 - momentum) * batch.
    # A Caffe2 momentum of 0.9 therefore corresponds to PyTorch momentum 0.1.
    # The value 0.9 is kept unchanged here pending confirmation of intent; it
    # only affects running-stat updates in train mode.
    for module in model.modules():
        if isinstance(module, nn.BatchNorm3d):
            module.eps = 1e-3
            module.momentum = 0.9

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)
    return model