in pycls/utils/metrics.py [0:0]
def _as_pair(value):
    """Returns ``value`` as an ``(h, w)`` tuple; a scalar is duplicated."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _out_size(size, kernel, stride, padding, dilation=1, ceil_mode=False):
    """Returns the output length of a sliding window along one spatial axis.

    Follows PyTorch's Conv2d/MaxPool2d shape formula, including dilation and,
    for ``ceil_mode``, the rule that the last window must start inside the
    (left-padded) input. With ``dilation=1`` and ``ceil_mode=False`` this is
    ``(size + 2 * padding - kernel) // stride + 1``.
    """
    span = size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return span // stride + 1
    out = -(-span // stride) + 1  # ceil division
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def _window_out_hw(m, h, w):
    """Returns ``(h_out, w_out)`` of conv/pool-like module ``m`` on an h x w input.

    Reads ``kernel_size``, ``stride`` and ``padding`` (int or pair) from ``m``;
    ``dilation`` and ``ceil_mode`` are used when present and otherwise default
    to 1 and False.
    """
    kh, kw = _as_pair(m.kernel_size)
    sh, sw = _as_pair(m.stride)
    ph, pw = _as_pair(m.padding)
    dh, dw = _as_pair(getattr(m, 'dilation', 1))
    ceil_mode = getattr(m, 'ceil_mode', False)
    return (_out_size(h, kh, sh, ph, dh, ceil_mode),
            _out_size(w, kw, sw, pw, dw, ceil_mode))


def flops_count(model, im_size=None):
    """Computes the number of flops of ``model``.

    Conv and linear layers are counted as multiply-accumulates. For a conv,
    this is ``weight.numel() * h_out * w_out``; for a linear layer, it is
    ``in_features * out_features``. Biases are not counted, except for
    convs whose name contains ``'.se'``. Talk* layers are additionally
    scaled by their ``flops_scale`` and truncated to an int.

    The spatial resolution ``(h, w)`` is tracked while walking
    ``model.named_modules()`` in registration order. The result is therefore
    only meaningful when modules are registered in the order they execute.

    Args:
        model: the network to analyse (anything with ``named_modules()``).
        im_size: optional square input resolution. When ``None`` (default)
            it is derived from ``cfg.TRAIN.DATASET``: 32 for ``'cifar10'``
            and 224 for ``'imagenet'``.

    Returns:
        int: the estimated flop count.

    Raises:
        AssertionError: if ``im_size`` is ``None`` and ``cfg.TRAIN.DATASET``
            is neither ``'cifar10'`` nor ``'imagenet'``.
    """
    if im_size is None:
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet'], \
            'Computing flops for {} is not supported'.format(cfg.TRAIN.DATASET)
        im_size = 32 if cfg.TRAIN.DATASET == 'cifar10' else 224
    h, w = im_size, im_size
    count = 0
    for n, m in model.named_modules():
        # NOTE(review): TalkConv2d is only reached if it is not an nn.Conv2d
        # subclass; TalkLinear is tested before nn.Linear so it wins if it
        # subclasses it. Keep this ordering.
        if isinstance(m, nn.Conv2d):
            if '.se' in n:
                # Presumably the squeeze-excitation convs, which act on
                # globally pooled features: count weights + bias only, with
                # no spatial factor. Guard against bias=False.
                bias = m.bias.numel() if m.bias is not None else 0
                count += m.in_channels * m.out_channels + bias
                continue
            h_out, w_out = _window_out_hw(m, h, w)
            count += m.weight.numel() * h_out * w_out
            # The shortcut projection presumably consumes the same input as
            # the main branch, so it must not advance the tracked resolution.
            if 'proj' not in n:
                h, w = h_out, w_out
        elif isinstance(m, TalkConv2d):
            h_out, w_out = _window_out_hw(m, h, w)
            count += int(m.weight.numel() * m.flops_scale * h_out * w_out)
            # As above; 'pool' Talk convs likewise leave the resolution alone
            # (intent not visible here — TODO confirm against the model code).
            if 'proj' not in n and 'pool' not in n:
                h, w = h_out, w_out
        elif isinstance(m, nn.MaxPool2d):
            # Pooling adds no counted flops but changes the resolution.
            h, w = _window_out_hw(m, h, w)
        elif isinstance(m, TalkLinear):
            count += int(m.in_features * m.out_features * m.flops_scale)
        elif isinstance(m, nn.Linear):
            count += m.in_features * m.out_features
    return count