in classy_vision/generic/profiler.py
def _layer_flops(layer: nn.Module, layer_args: List[Any], y: Any) -> int:
"""
Computes the number of FLOPs required for a single layer.
For common layers, such as Conv1d, the flop compute is implemented in this
centralized place.
For other layers, if it defines a method to compute flops with the signature
below, we will use it to compute flops.
Class MyModule(nn.Module):
def flops(self, x):
...
"""
x = layer_args[0]
# get layer type:
    typestr = repr(layer)
layer_type = typestr[: typestr.find("(")].strip()
batchsize_per_replica = get_batchsize_per_replica(x)
flops = None
# 1D convolution:
if layer_type in ["Conv1d"]:
# x shape is N x C x W
out_w = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
/ layer.stride[0]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* out_w
/ layer.groups
)
# 2D convolution:
elif layer_type in ["Conv2d"]:
out_h = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
/ layer.stride[0]
+ 1
)
out_w = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
/ layer.stride[1]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* out_h
* out_w
/ layer.groups
)
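        # Illustrative example (hypothetical numbers): a 3x3 Conv2d mapping 64 -> 128
        # channels onto a 56x56 output (groups=1, batch size 1) counts
        # 1 * 64 * 128 * 3 * 3 * 56 * 56 = 231,211,008 multiply-accumulates.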
# learned group convolution:
elif layer_type in ["LearnedGroupConv"]:
conv = layer.conv
out_h = int(
(x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / conv.stride[0]
+ 1
)
out_w = int(
(x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / conv.stride[1]
+ 1
)
        # relu and norm preserve the input shape, so x stands in for their output
        count1 = _layer_flops(layer.relu, [x], x) + _layer_flops(
            layer.norm, [x], x
        )
count2 = (
batchsize_per_replica
* conv.in_channels
* conv.out_channels
* conv.kernel_size[0]
* conv.kernel_size[1]
* out_h
* out_w
/ layer.condense_factor
)
flops = count1 + count2
# non-linearities:
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
flops = x.numel()
# 2D pooling layers:
elif layer_type in ["AvgPool2d", "MaxPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (layer.kernel_size, layer.kernel_size)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1]
out_h = 1 + int(
(in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride
)
out_w = 1 + int(
(in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride
)
flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
# adaptive avg pool2d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
elif layer_type in ["AdaptiveAvgPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.output_size, int):
out_h, out_w = layer.output_size, layer.output_size
elif len(layer.output_size) == 1:
out_h, out_w = layer.output_size[0], layer.output_size[0]
else:
out_h, out_w = layer.output_size
if out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kh * kw
flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops
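        # Illustrative example (hypothetical numbers): AdaptiveAvgPool2d(1) on a
        # 1 x 2048 x 7 x 7 input gives kh = kw = 7, so
        # flops = 1 * 2048 * 1 * 1 * 49 = 100,352.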
# linear layer:
elif layer_type in ["Linear"]:
weight_ops = layer.weight.numel()
bias_ops = layer.bias.numel() if layer.bias is not None else 0
flops = ((x.numel() / x.size(-1)) if x.ndim > 2 else x.size(0)) * (
weight_ops + bias_ops
)
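        # Illustrative example (hypothetical numbers): Linear(512, 1000) with bias
        # on a 32 x 512 input counts 32 * (512 * 1000 + 1000) = 16,416,000 flops.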
# batch normalization / layer normalization:
elif layer_type in [
"BatchNorm1d",
"BatchNorm2d",
"BatchNorm3d",
"SyncBatchNorm",
"LayerNorm",
]:
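        # roughly one multiply and one add per element (normalize, then scale/shift)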
flops = 2 * x.numel()
# 3D convolution
elif layer_type in ["Conv3d"]:
out_t = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
// layer.stride[0]
+ 1
)
out_h = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
// layer.stride[1]
+ 1
)
out_w = int(
(x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
// layer.stride[2]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* layer.kernel_size[2]
* out_t
* out_h
* out_w
/ layer.groups
)
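        # Illustrative example (hypothetical numbers): a 3x7x7 Conv3d mapping 3 -> 64
        # channels onto an 8x56x56 output (groups=1, batch size 1) counts
        # 1 * 3 * 64 * 3 * 7 * 7 * 8 * 56 * 56 = 708,083,712 multiply-accumulates.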
# 3D pooling layers
elif layer_type in ["AvgPool3d", "MaxPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (
layer.kernel_size,
layer.kernel_size,
layer.kernel_size,
)
if isinstance(layer.padding, int):
layer.padding = (layer.padding, layer.padding, layer.padding)
if isinstance(layer.stride, int):
layer.stride = (layer.stride, layer.stride, layer.stride)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2]
out_t = 1 + int(
(in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0]
)
out_h = 1 + int(
(in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1]
)
out_w = 1 + int(
(in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2]
)
flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops
# adaptive avg pool3d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
elif layer_type in ["AdaptiveAvgPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
        if isinstance(layer.output_size, int):
            out_t = out_h = out_w = layer.output_size
        else:
            out_t, out_h, out_w = layer.output_size
if out_t > in_t or out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kt = in_t - out_t + 1
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kt * kh * kw
flops = (
batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
)
elif layer_type in ["Dropout", "Identity"]:
flops = 0
elif hasattr(layer, "flops"):
        # If the module already defines a method to compute flops with the
        # signature below, we use it to compute flops:
        #
        #     class MyModule(nn.Module):
        #         def flops(self, x):
        #             ...
        #
        # or, for modules that take multiple inputs:
        #
        #     class MyModule(nn.Module):
        #         def flops(self, x1, x2):
        #             ...
flops = layer.flops(*layer_args)
if flops is None:
raise ClassyProfilerNotImplementedError(layer)
message = [
f"module type: {typestr}",
f"input size: {get_shape(x)}",
f"output size: {get_shape(y)}",
f"params(M): {count_params(layer) / 1e6}",
f"flops(M): {int(flops) / 1e6}",
]
logging.debug("\t".join(message))
return int(flops)
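

# Illustrative sketch (not part of profiler.py): a hypothetical custom module that
# exposes its own flop count, which _layer_flops picks up through the
# hasattr(layer, "flops") branch above. The module name and numbers are made up.
import torch
import torch.nn as nn


class ScaleShift(nn.Module):
    """Hypothetical module computing y = x * scale + shift elementwise."""

    def __init__(self, num_features: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))
        self.shift = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return x * self.scale + self.shift

    def flops(self, x):
        # one multiply and one add per element
        return 2 * x.numel()


# With a flops() method like this in place, profiling a model containing ScaleShift
# uses the module's own count instead of raising ClassyProfilerNotImplementedError.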