# in network/resnet101_3d_gcn_x5.py [0:0]
def __init__(self, num_classes, pretrained=False, **kwargs):
    """Build a ResNet-101-style 3D CNN with GloRe (global reasoning) units.

    Stage layout follows ResNet-101 (3, 4, 23, 3 residual blocks); GloRe
    units are interleaved into conv3 (after blocks 1 and 3) and conv4
    (after blocks 6, 12 and 18).

    Args:
        num_classes (int): size of the final linear classifier output.
        pretrained (bool): if True, initialize 3D weights by inflating the
            2D checkpoint ``pretrained/resnet101-lite.pth`` located next to
            this file; otherwise keep the Xavier random initialization.
        **kwargs: accepted for interface compatibility; unused here.
    """
    super(RESNET101_3D_GCN_X5, self).__init__()

    groups = 1
    # Residual blocks per stage (ResNet-101: conv2..conv5 = 3, 4, 23, 3).
    k_sec = {2: 3,
             3: 4,
             4: 23,
             5: 3}

    # conv1 - x112 (x16): stem; spatial stride 2 in the conv plus stride 2
    # in the pool, temporal resolution preserved.
    conv1_num_out = 32
    self.conv1 = nn.Sequential(OrderedDict([
        ('conv', nn.Conv3d(3, conv1_num_out, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2), bias=False)),
        ('bn', nn.BatchNorm3d(conv1_num_out, eps=1e-04)),
        ('relu', nn.ReLU(inplace=True)),
        ('max_pool', nn.MaxPool3d(kernel_size=(1, 3, 3), padding=(0, 1, 1), stride=(1, 2, 2))),
    ]))

    # conv2 - x56 (x16): no downsampling anywhere in this stage.
    num_mid = 64
    conv2_num_out = 256
    self.conv2 = nn.Sequential(OrderedDict([
        ("B%02d" % i, RESIDUAL_BLOCK(num_in=conv1_num_out if i == 1 else conv2_num_out,
                                     num_mid=num_mid,
                                     num_out=conv2_num_out,
                                     # fixed: original had `(1,1,1) if i==1 else (1,1,1)`,
                                     # a dead conditional with identical branches
                                     stride=(1, 1, 1),
                                     g=groups,
                                     first_block=(i == 1))) for i in range(1, k_sec[2] + 1)
    ]))

    # conv3 - x28 (x8): first block downsamples time and space (2,2,2);
    # GloRe units inserted after blocks 1 and 3.
    num_mid *= 2
    conv3_num_out = 2 * conv2_num_out
    blocks = []
    for i in range(1, k_sec[3] + 1):
        use_3d = bool(i % 2)  # alternate: odd-indexed blocks use 3D convs
        blocks.append(("B%02d" % i, RESIDUAL_BLOCK(num_in=conv2_num_out if i == 1 else conv3_num_out,
                                                   num_mid=num_mid,
                                                   num_out=conv3_num_out,
                                                   stride=(2, 2, 2) if i == 1 else (1, 1, 1),
                                                   use_3d=use_3d,
                                                   g=groups,
                                                   first_block=(i == 1))))
        if i in [1, 3]:
            blocks.append(("B%02d_extra" % i, GloRe_Unit(num_in=conv3_num_out, num_mid=num_mid)))
    self.conv3 = nn.Sequential(OrderedDict(blocks))

    # conv4 - x14 (x8): spatial-only downsampling in the first block;
    # GloRe units inserted after blocks 6, 12 and 18.
    num_mid *= 2
    conv4_num_out = 2 * conv3_num_out
    blocks = []
    for i in range(1, k_sec[4] + 1):
        use_3d = bool(i % 2)  # alternate: odd-indexed blocks use 3D convs
        blocks.append(("B%02d" % i, RESIDUAL_BLOCK(num_in=conv3_num_out if i == 1 else conv4_num_out,
                                                   num_mid=num_mid,
                                                   num_out=conv4_num_out,
                                                   stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                                                   use_3d=use_3d,
                                                   g=groups,
                                                   first_block=(i == 1))))
        if i in [6, 12, 18]:
            blocks.append(("B%02d_extra" % i, GloRe_Unit(num_in=conv4_num_out, num_mid=num_mid)))
    self.conv4 = nn.Sequential(OrderedDict(blocks))

    # conv5 - x7 (x4): spatial-only downsampling in the first block;
    # only the second block uses 3D convs here.
    num_mid *= 2
    conv5_num_out = 2 * conv4_num_out
    self.conv5 = nn.Sequential(OrderedDict([
        ("B%02d" % i, RESIDUAL_BLOCK(num_in=conv4_num_out if i == 1 else conv5_num_out,
                                     num_mid=num_mid,
                                     num_out=conv5_num_out,
                                     stride=(1, 2, 2) if i == 1 else (1, 1, 1),
                                     g=groups,
                                     use_3d=(i == 2),
                                     first_block=(i == 1))) for i in range(1, k_sec[5] + 1)
    ]))

    # final: pre-activation tail, global average pooling over the remaining
    # 4x7x7 volume, dropout, then the linear classifier.
    self.tail = nn.Sequential(OrderedDict([
        ('bn', nn.BatchNorm3d(conv5_num_out, eps=1e-04)),
        ('relu', nn.ReLU(inplace=True))
    ]))
    self.globalpool = nn.Sequential(OrderedDict([
        ('avg', nn.AvgPool3d(kernel_size=(4, 7, 7), stride=(1, 1, 1))),
        ('dropout', nn.Dropout(p=0.5)),
    ]))
    self.classifier = nn.Linear(conv5_num_out, num_classes)

    #############
    # Initialization
    initializer.xavier(net=self)

    if pretrained:
        import torch
        load_method = 'inflation'  # 'random', 'inflation'
        pretrained_model = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'pretrained/resnet101-lite.pth')
        logging.info("Network:: symbol initialized, use pretrained model: `{}'".format(pretrained_model))
        assert os.path.exists(pretrained_model), "cannot locate: `{}'".format(pretrained_model)
        state_dict_2d = torch.load(pretrained_model)
        initializer.init_3d_from_2d_dict(net=self, state_dict=state_dict_2d, method=load_method)
    else:
        logging.info("Network:: symbol initialized, use random inilization!")

    # Zero all `blocker` weights so the corresponding branches start as
    # identity/no-op paths. state_dict() tensors alias the live parameters,
    # so in-place assignment modifies the module weights directly.
    blocker_name_list = []
    for name, param in self.state_dict().items():
        if name.endswith('blocker.weight'):
            blocker_name_list.append(name)
            param[:] = 0.
    if len(blocker_name_list) > 0:
        logging.info("Network:: change params of the following layer be zeros: {}".format(blocker_name_list))