in eval_linear.py [0:0]
def __init__(self, num_labels, arch="resnet50", global_avg=False, use_bn=True):
    """Build the pooling, optional batch-norm and linear layers.

    Args:
        num_labels: Output size of the final ``nn.Linear`` layer.
        arch: Backbone name used to pick the linear layer's input size.
            With ``global_avg=True`` it must be ``"resnet50"`` (2048),
            ``"resnet50w2"`` (4096) or ``"resnet50w4"`` (8192). With
            ``global_avg=False`` only ``"resnet50"`` is accepted.
        global_avg: When true, pool with ``nn.AdaptiveAvgPool2d((1, 1))``.
            Otherwise pool with ``nn.AvgPool2d(6, stride=1)`` and use a
            linear input size of 8192.
        use_bn: When true and ``global_avg`` is false, ``self.bn`` is set
            to ``nn.BatchNorm2d(2048)``; otherwise ``self.bn`` stays
            ``None``.

    Raises:
        ValueError: ``global_avg`` is true and ``arch`` is not one of the
            supported names. (Previously this surfaced as an
            ``UnboundLocalError`` when the linear layer was created.)
        AssertionError: ``global_avg`` is false and ``arch`` is not
            ``"resnet50"``.
    """
    super(RegLog, self).__init__()
    self.bn = None

    if global_avg:
        # Linear input size per supported backbone when the feature map is
        # pooled down to 1x1.
        feature_dims = {
            "resnet50": 2048,
            "resnet50w2": 4096,
            "resnet50w4": 8192,
        }
        if arch not in feature_dims:
            # Fail here with a clear message instead of leaving the size
            # unbound and crashing later at nn.Linear.
            raise ValueError(
                "unsupported arch {!r} with global_avg=True; expected one of {}".format(
                    arch, sorted(feature_dims)
                )
            )
        in_features = feature_dims[arch]
        self.av_pool = nn.AdaptiveAvgPool2d((1, 1))
    else:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if arch != "resnet50":
            raise AssertionError(
                "arch must be 'resnet50' when global_avg=False, got {!r}".format(arch)
            )
        in_features = 8192
        self.av_pool = nn.AvgPool2d(6, stride=1)
        # NOTE(review): nesting reconstructed from a flattened listing; the
        # batch norm is assumed to belong to this branch only, since its
        # hard-coded 2048 channels match the resnet50-only path. Confirm
        # against the original file.
        if use_bn:
            self.bn = nn.BatchNorm2d(2048)

    self.linear = nn.Linear(in_features, num_labels)
    # Initialise the classifier: weights ~ N(0, 0.01), bias = 0.
    self.linear.weight.data.normal_(mean=0.0, std=0.01)
    self.linear.bias.data.zero_()