in pycls/utils/metrics.py [0:0]
def measure_model(model, H, W):
    """Measure a model by running one instrumented forward pass.

    The module-level counters ``count_ops`` and ``count_params`` are reset
    to 0, the ``forward`` of every measured submodule is temporarily wrapped
    so that ``measure_layer`` is called with the module and its input, and
    the model is run once on an all-zero input of shape ``(1, 3, H, W)``
    placed on the CUDA device.

    NOTE(review): ``measure_layer``, ``is_leaf`` and ``is_pruned`` are
    defined outside this block; presumably ``measure_layer`` accumulates
    into ``count_ops`` / ``count_params`` -- confirm against its definition.

    Args:
        model: Module exposing ``children()`` and ``forward()``.
        H: Height of the dummy input.
        W: Width of the dummy input.

    Returns:
        Tuple ``(count_ops, count_params)`` read from the module globals
        after the forward pass.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    data = torch.zeros(1, 3, H, W).cuda()

    def should_measure(x):
        # A module is instrumented as a unit when it is a leaf or is pruned.
        return is_leaf(x) or is_pruned(x)

    def modify_forward(model):
        # Wrap the forward of each measured module; descend into the rest.
        for child in model.children():
            if should_measure(child):

                def new_forward(m):
                    # Factory so that each wrapper closes over its own module
                    # (avoids the late-binding closure pitfall in the loop).
                    def lambda_forward(x):
                        measure_layer(m, x)
                        return m.old_forward(x)

                    return lambda_forward

                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        # Mirror modify_forward exactly: the same predicate decides whether a
        # module was patched. Checking only is_leaf() would leave pruned
        # non-leaf modules wrapped, and a later measure_model() call would
        # then wrap the wrapper and recurse forever through old_forward.
        for child in model.children():
            if should_measure(child):
                # old_forward is reset to None after restoring, so a plain
                # hasattr() check cannot tell whether the module is patched.
                if getattr(child, "old_forward", None) is not None:
                    child.forward = child.old_forward
                    child.old_forward = None
            else:
                restore_forward(child)

    try:
        modify_forward(model)
        model.forward(data)
    finally:
        # Always undo the patching, even if instrumentation or the forward
        # pass raises, so the caller's model is never left wrapped.
        restore_forward(model)

    return count_ops, count_params