def get_flops_params()

in lib/utils/misc.py


def get_flops_params(model):
    """
    Calculating flops and the number of parameters for Conv, FC, and
    BatchMatMul.
    """

    model_ops = model.net.Proto().op
    master_gpu = 'gpu_{}'.format(cfg.ROOT_GPU_ID)

    bs = get_batch_size(model.split)

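    # Collect the Conv / FC / BatchMatMul ops that run on the master GPU so
    # that data-parallel replicas on the other GPUs are not counted.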
    param_ops = []
    for op in model_ops:
        if op.type in ['Conv', 'FC', 'BatchMatMul'] \
                and op.input[0].find(master_gpu) >= 0:
            param_ops.append(op)

    num_flops = 0
    num_params = 0

    for op in param_ops:
        op_type = op.type
        op_inputs = op.input
        op_output = op.output[0]
        layer_flops = 0
        layer_params = 0
        correct_factor = 1

        if op_type == 'Conv':
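            # Conv: per-sample FLOPs (multiply-accumulates) are the number of
            # weight parameters times the number of output positions
            # (np.prod(output_shape[2:])); biases are not counted.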
            for op_input in op_inputs:
                if '_w' in op_input:
                    param_blob = op_input
                    param_shape = np.array(
                        workspace.FetchBlob(str(param_blob))).shape
                    layer_params = np.prod(param_shape)
                    output_shape = np.array(
                        workspace.FetchBlob(str(op_output))).shape
                    layer_flops = layer_params * np.prod(output_shape[2:])

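                    # When the output batch dimension is larger than the
                    # loader batch size (presumably multiple clips / crops per
                    # example), scale FLOPs by the ratio.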
                    if output_shape[0] > bs:
                        correct_factor = int(float(output_shape[0]) // bs)
                        layer_flops *= correct_factor
        elif op_type == 'FC':
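            # FC: each weight contributes one multiply-accumulate per sample,
            # so per-sample FLOPs equal the number of weight parameters.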
            for op_input in op_inputs:
                if '_w' in op_input:
                    param_blob = op_input
                    param_shape = np.array(
                        workspace.FetchBlob(str(param_blob))).shape
                    output_shape = np.array(
                        workspace.FetchBlob(str(op_output))).shape

                    layer_params = np.prod(param_shape)
                    layer_flops = layer_params

                    if output_shape[0] > bs:
                        correct_factor = int(float(output_shape[0]) // bs)
                        layer_flops *= correct_factor

        elif op_type == 'BatchMatMul':
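            # BatchMatMul: ignore gradient-related and memory-shared blobs
            # ('grad' / 'shared' in the name) as well as gradient ops, then
            # count M * K * N multiply-accumulates per sample; the operands
            # are activations, so no parameters are added.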
            if 'grad' in op_inputs[0] or 'grad' in op_inputs[1]:
                continue
            if 'shared' in op_inputs[0] or 'shared' in op_inputs[1]:
                continue

            if op.is_gradient_op:
                continue

            param_shape_a = np.array(
                workspace.FetchBlob(str(op_inputs[0]))).shape
            param_shape_b = np.array(
                workspace.FetchBlob(str(op_inputs[1]))).shape

            output_shape = np.array(
                workspace.FetchBlob(str(op_output))).shape

            correct_factor = output_shape[0] // bs

            param_shape_a = param_shape_a[1:]
            param_shape_b = param_shape_b[1:]

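            # Only the first op argument is inspected; it is assumed to be the
            # trans_a / trans_b flag indicating which operand is stored
            # transposed.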
            if op.arg[0].name == 'trans_a':
                param_shape_a = param_shape_a[::-1]
            elif op.arg[0].name == 'trans_b':
                param_shape_b = param_shape_b[::-1]
            else:
                raise NotImplementedError('trans_a or trans_b')

            layer_flops = param_shape_a[0] * param_shape_a[1] \
                * param_shape_b[1] * correct_factor

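        # Log the per-layer cost; the number in parentheses is the batch
        # correction factor applied above.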
        logger.info('layer {} ({}) FLOPs: {:.2f} M PARAMs: {:.2f} K'.format(
                    op.output[0], correct_factor,
                    layer_flops / 1e6, layer_params / 1e3))

        num_flops += layer_flops
        num_params += layer_params
    return num_flops, num_params
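
A minimal usage sketch: it assumes the net has already been built and run at
least once (so that every blob fetched above exists in the Caffe2 workspace)
and that `logger` is the module-level logger used by the function.

total_flops, total_params = get_flops_params(model)
logger.info('Total: {:.2f} GFLOPs, {:.2f} M params'.format(
    total_flops / 1e9, total_params / 1e6))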