in models/attentive_nas_dynamic_model.py [0:0]
def compute_active_subnet_flops(self):
    """Estimate the compute cost of the currently active sub-network.

    Walks the active configuration and sums the multiply-accumulate count of
    every convolution plus the final linear classifier. The configuration
    covered is:

    * resolution;
    * per-stage runtime depth;
    * per-block expand ratio, kernel size and output width;
    * head type.

    BN, activations, pooling, biases and residual additions are not counted.

    Returns:
        The total op count divided by 1e6 (i.e. in millions).
    """

    def count_conv(c_in, c_out, size_out, groups, k):
        # Each of the c_out * size_out**2 outputs reads (c_in / groups) * k * k
        # inputs; assumes a square k x k kernel on a square feature map.
        kernel_ops = k**2
        output_elements = c_out * size_out**2
        return c_in * output_elements * kernel_ops / groups

    def count_linear(c_in, c_out):
        return c_in * c_out

    def conv_out_size(size_in, stride):
        # Ceil division: equals size_in // stride when divisible and rounds up
        # otherwise, which is the output size of a strided conv with
        # "same"-style padding (pad = k // 2, odd k). Valid for any stride.
        return -(-size_in // stride)

    total_ops = 0

    # Stem conv: 3x3 on the 3-channel input. Rounded up like the block convs
    # below so odd resolutions are not under-counted. NOTE(review): assumes
    # first_conv uses the same "same"-style padding - confirm.
    c_in = 3
    size_out = conv_out_size(self.active_resolution, self.first_conv.stride)
    c_out = self.first_conv.active_out_channel
    total_ops += count_conv(c_in, c_out, size_out, 1, 3)
    c_in = c_out

    # MB blocks: only the first runtime_depth[stage_id] blocks of each stage run.
    for stage_id, block_idx in enumerate(self.block_group_info):
        depth = self.runtime_depth[stage_id]
        active_idx = block_idx[:depth]
        for idx in active_idx:
            block = self.blocks[idx]
            mb_conv = block.mobile_inverted_conv
            c_middle = make_divisible(round(c_in * mb_conv.active_expand_ratio), 8)

            # 1x1 expansion conv (skipped when the block has no inverted bottleneck).
            if mb_conv.inverted_bottleneck is not None:
                total_ops += count_conv(c_in, c_middle, size_out, 1, 1)

            # Depthwise conv; only the first active block of a stage strides.
            stride = 1 if idx > active_idx[0] else mb_conv.stride
            size_out = conv_out_size(size_out, stride)
            total_ops += count_conv(
                c_middle, c_middle, size_out, c_middle, mb_conv.active_kernel_size
            )

            # 1x1 projection conv.
            c_out = mb_conv.active_out_channel
            total_ops += count_conv(c_middle, c_out, size_out, 1, 1)

            # Squeeze-and-excitation: two equal-cost 1x1 convs
            # (c_middle -> num_mid and back) on a 1x1 map.
            if mb_conv.use_se:
                num_mid = make_divisible(
                    c_middle // mb_conv.depth_conv.se.reduction, divisor=8
                )
                total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2

            # Extra 1x1 conv on the shortcut when the channel count changes.
            if block.shortcut and c_in != c_out:
                total_ops += count_conv(c_in, c_out, size_out, 1, 1)
            c_in = c_out

    if not self.use_v3_head:
        # Single 1x1 conv on the final feature map.
        c_out = self.last_conv.active_out_channel
        total_ops += count_conv(c_in, c_out, size_out, 1, 1)
    else:
        # V3-style head: 1x1 expand at full spatial size, then a 1x1
        # feature-mix conv costed on a 1x1 map (presumably after pooling).
        c_expand = self.last_conv.final_expand_layer.active_out_channel
        c_out = self.last_conv.feature_mix_layer.active_out_channel
        total_ops += count_conv(c_in, c_expand, size_out, 1, 1)
        total_ops += count_conv(c_expand, c_out, 1, 1, 1)

    # Final linear classifier.
    total_ops += count_linear(c_out, self.n_classes)
    return total_ops / 1e6