def flops_count()

in pycls/utils/metrics.py [0:0]


def _flops_hw_pair(value):
    """Return a conv/pool hyper-parameter as an ``(h, w)`` pair.

    PyTorch layers accept ``kernel_size``/``stride``/``padding``/``dilation``
    either as a single int or as a 2-sequence (``nn.MaxPool2d`` keeps the value
    in whichever form it was given), so both forms are normalized here.
    """
    if isinstance(value, (tuple, list)):
        return value[0], value[1]
    return value, value


def _flops_out_size(size, kernel, stride, padding, dilation=1, ceil_mode=False):
    """Return the output length of a sliding window along one spatial axis.

    Follows the PyTorch Conv2d/MaxPool2d shape formula
    ``(size + 2*padding - dilation*(kernel-1) - 1) / stride + 1``, rounded down,
    or rounded up when ``ceil_mode`` is set. With ``dilation == 1`` and no
    ``ceil_mode`` this is ``(size + 2*padding - kernel) // stride + 1``.
    """
    numer = size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return numer // stride + 1
    out = -(-numer // stride) + 1
    # In ceil_mode PyTorch discards a last window that would start inside the
    # right-hand padding.
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def _flops_window_out_hw(m, h, w):
    """Return ``(h_out, w_out)`` of conv/pool module ``m`` for an ``h x w`` input.

    ``dilation`` and ``ceil_mode`` are read with defaults (1 and False), so
    modules lacking those attributes are handled as undilated, floor-mode
    windows.
    """
    kh, kw = _flops_hw_pair(m.kernel_size)
    sh, sw = _flops_hw_pair(m.stride)
    ph, pw = _flops_hw_pair(m.padding)
    dh, dw = _flops_hw_pair(getattr(m, 'dilation', 1))
    ceil_mode = getattr(m, 'ceil_mode', False)
    return (_flops_out_size(h, kh, sh, ph, dh, ceil_mode),
            _flops_out_size(w, kw, sw, pw, dw, ceil_mode))


def flops_count(model):
    """Computes the number of flops.

    The count is a static estimate. It walks ``model.named_modules()`` while
    tracking the current spatial resolution ``(h, w)``, which starts at 32x32
    for ``cifar10`` and 224x224 for ``imagenet`` (``cfg.TRAIN.DATASET``).

    Per-module contribution:
      * ``nn.Conv2d`` whose name contains ``'.se'``:
        ``in_channels * out_channels`` plus the bias size (0 if no bias).
        This is independent of resolution, and the tracked resolution is not
        updated.
      * Other ``nn.Conv2d``: ``weight.numel() * h_out * w_out``, i.e. one
        multiply-accumulate per weight per output position (bias not
        counted). The tracked resolution is updated unless the name contains
        ``'proj'``.
      * ``TalkConv2d``: like a conv, scaled by ``m.flops_scale`` and truncated
        to int. The resolution is updated unless the name contains ``'proj'``
        or ``'pool'``.
      * ``nn.MaxPool2d``: no flops; only updates the tracked resolution.
      * ``TalkLinear``: ``in_features * out_features * flops_scale``,
        truncated to int.
      * ``nn.Linear``: ``in_features * out_features``.

    NOTE(review): this relies on ``named_modules()`` (registration order)
    matching the forward execution order. It also relies on naming
    conventions: ``'.se'`` presumably marks squeeze-excitation convs run on
    pooled features, and ``'proj'`` presumably marks shortcut projections
    parallel to the main branch. Confirm against the model definitions.

    Args:
        model: the ``torch.nn.Module`` to analyse.

    Returns:
        int: the estimated flop count.

    Raises:
        AssertionError: if ``cfg.TRAIN.DATASET`` is not ``'cifar10'`` or
            ``'imagenet'``.
    """
    assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet'], \
        'Computing flops for {} is not supported'.format(cfg.TRAIN.DATASET)
    im_size = 32 if cfg.TRAIN.DATASET == 'cifar10' else 224
    h, w = im_size, im_size
    count = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if '.se' in n:
                # Guard bias=False convs, whose .bias is None.
                bias = m.bias.numel() if m.bias is not None else 0
                count += m.in_channels * m.out_channels + bias
                continue
            h_out, w_out = _flops_window_out_hw(m, h, w)
            count += int(m.weight.numel() * h_out * w_out)
            if 'proj' not in n:
                h, w = h_out, w_out
        elif isinstance(m, TalkConv2d):
            h_out, w_out = _flops_window_out_hw(m, h, w)
            count += int(m.weight.numel() * m.flops_scale * h_out * w_out)
            if 'proj' not in n and 'pool' not in n:
                h, w = h_out, w_out
        elif isinstance(m, nn.MaxPool2d):
            # MaxPool2d may store int or tuple params; the helper handles both
            # and honours dilation/ceil_mode.
            h, w = _flops_window_out_hw(m, h, w)
        elif isinstance(m, TalkLinear):
            # Checked before nn.Linear so the flops_scale branch takes priority.
            count += int(m.in_features * m.out_features * m.flops_scale)
        elif isinstance(m, nn.Linear):
            count += m.in_features * m.out_features

    return count