in self_supervision_benchmark/utils/helpers.py [0:0]
def get_flops_params(model, device_id=0):
    """Estimate the FLOPs and parameter count of a model's Conv and FC layers.

    Only Conv/FC ops whose first input blob name refers to the master device
    (``<prefix><device_id>``, with the prefix from ``get_prefix_and_device()``)
    are counted, so replicas of the net on other devices are not counted again.
    Shapes are read from blobs in the Caffe2 workspace: the weight blobs and,
    for Conv, the output blobs must already exist there (e.g. after a forward
    pass), otherwise ``workspace.FetchBlob`` fails.

    FLOPs are counted as multiply-accumulates: for Conv, one per weight per
    output spatial position; for FC, one per weight. Biases are ignored.

    Args:
        model: model helper whose ``net.Proto().op`` lists the ops to inspect.
        device_id (int): index of the device whose ops are counted.

    Returns:
        tuple: ``(num_flops, num_params)`` in units of 10^9 and 10^6.
    """
    import re

    prefix, _ = get_prefix_and_device()
    master_device = prefix + str(device_id)
    # Reject a following digit so that e.g. 'gpu_1' does not also match the
    # blobs of 'gpu_10', 'gpu_11', ... on machines with 10+ devices.
    device_pattern = re.compile(re.escape(master_device) + r'(?!\d)')

    total_flops = 0
    total_params = 0
    for op in model.net.Proto().op:
        # Test the type before touching op.input: ops without inputs would
        # otherwise raise IndexError. Conv/FC always have at least (X, W).
        if op.type not in ('Conv', 'FC') or len(op.input) < 2:
            continue
        if not device_pattern.search(op.input[0]):
            continue
        # Caffe2 Conv and FC both take inputs (X, W[, b]); input 1 is the
        # weight, which is more reliable than guessing from the blob name.
        weight_shape = np.shape(workspace.FetchBlob(str(op.input[1])))
        layer_params = int(np.prod(weight_shape))
        if op.type == 'Conv':
            # Each weight is applied once per output spatial position.
            # Assumes an NCHW-style layout where dims 2+ are spatial
            # (H, W for 2-D conv; D, H, W for 3-D conv).
            output_shape = np.shape(workspace.FetchBlob(str(op.output[0])))
            layer_flops = layer_params * int(np.prod(output_shape[2:]))
        else:
            layer_flops = layer_params
        total_params += layer_params
        total_flops += layer_flops

    # Scale once at the end with float divisors so the result is a true
    # quotient even under Python 2 integer-division semantics.
    num_flops = total_flops / 1e9
    num_params = total_params / 1e6
    logger.info('Total network FLOPs (10^9): {}'.format(num_flops))
    logger.info('Total network params (10^6): {}'.format(num_params))
    return num_flops, num_params