in utils/flops_counter.py
import torch


def profile(model, input_size=(1, 3, 224, 224), custom_ops=None):
    """Count total ops and parameters of `model` by attaching counting
    hooks to every leaf module and running one dummy forward pass."""
    handler_collection = []
    custom_ops = {} if custom_ops is None else custom_ops

    def add_hooks(m_):
        # Only instrument leaf modules; parents are covered by their children.
        if len(list(m_.children())) > 0:
            return

        # Per-module accumulators, written into by the forward hooks.
        m_.register_buffer('total_ops', torch.zeros(1))
        m_.register_buffer('total_params', torch.zeros(1))
        for p in m_.parameters():
            m_.total_params += torch.Tensor([p.numel()])

        # Prefer a user-supplied counter, then fall back to the built-in table.
        m_type = type(m_)
        fn = None
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
        # else: no counter known for this module type; it contributes 0 ops.

        if fn is not None:
            _handler = m_.register_forward_hook(fn)
            handler_collection.append(_handler)

    original_device = next(model.parameters()).device
    training = model.training

    model.eval()
    model.apply(add_hooks)

    # One dummy forward pass on the model's device triggers every hook.
    x = torch.zeros(input_size).to(original_device)
    with torch.no_grad():
        model(x)

    # Sum the per-leaf accumulators.
    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue
        total_ops += m.total_ops
        total_params += m.total_params
    total_ops = total_ops.item()
    total_params = total_params.item()

    # Restore the model's training state and detach all counting hooks.
    model.train(training)
    model.to(original_device)
    for handler in handler_collection:
        handler.remove()

    return total_ops, total_params
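
`register_hooks` is defined elsewhere in this file: a module-level table mapping leaf-module types to forward hooks that accumulate into the `total_ops` buffer created above. As a minimal sketch of the contract such a hook must satisfy (the name `count_conv2d` and the exact multiply-accumulate formula below are illustrative assumptions, not necessarily what this file ships):

import torch
import torch.nn as nn

def count_conv2d(m, x, y):
    # Forward hooks receive (module, inputs, output); inputs is a tuple.
    # MACs per output element: kernel area times input channels per group,
    # plus one add if the layer has a bias.
    kernel_ops = m.weight.size(2) * m.weight.size(3) * (m.in_channels // m.groups)
    bias_ops = 1 if m.bias is not None else 0
    m.total_ops += torch.Tensor([y.numel() * (kernel_ops + bias_ops)])

register_hooks = {
    nn.Conv2d: count_conv2d,
    # ... one entry per supported layer type ...
}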
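
A typical call looks like the sketch below, assuming torchvision is installed; `count_silu` is a hypothetical custom counter showing how `custom_ops` extends or overrides the built-in table. Whether the returned number reads as MACs or FLOPs depends on the convention the hooks use.

import torch
import torch.nn as nn
from torchvision.models import resnet18

def count_silu(m, x, y):
    # Hypothetical convention: charge one op per output element for SiLU.
    m.total_ops += torch.Tensor([y.numel()])

model = resnet18()
ops, params = profile(model, input_size=(1, 3, 224, 224),
                      custom_ops={nn.SiLU: count_silu})
print("%.2f GMac, %.2fM params" % (ops / 1e9, params / 1e6))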