def get_ops_from_net()

in runtime_lut/code/api.py [0:0]


def get_ops_from_net(model, blobs, input_dims):
    """Collect every operator of a model with its input shapes and dtypes.

    Shapes and dtypes of all blobs are obtained from
    ``model_utils.infer_model_shape_by_ops``; the values of the blobs passed
    in are only used to seed that inference.

    Args:
        model: Either a ``pb2.NetDef`` or an object exposing ``Proto()`` that
            returns one.
        blobs: Mapping from blob name to blob value. Entries whose value is a
            ``str`` are ignored. The mapping itself is not modified.
        input_dims: Mapping from input blob name to its shape. Each input is
            fed to shape inference as a float32 array of ones with that shape,
            overriding any entry of the same name in ``blobs``.

    Returns:
        A tuple ``(ops, input_shapes, input_dtypes)`` of three parallel lists
        with one entry per operator, in the order of ``proto.op``:
        ``ops`` holds the operator protos, ``input_shapes`` holds a list of
        shape tuples (one per operator input) and ``input_dtypes`` holds a
        list of ``encode_dtype`` results (one per operator input).

    Raises:
        AssertionError: If a "Conv" or "FC" operator does not have 2 or 3
            inputs.
        KeyError: If an operator input has no inferred shape.
    """
    # Build a fresh dict so the caller's mapping is left untouched.
    seed_blobs = {
        name: value for name, value in blobs.items() if not isinstance(value, str)
    }
    seed_blobs.update(
        {name: np.ones(dims, dtype=np.float32) for name, dims in input_dims.items()}
    )

    blobs_dims, blobs_dtypes = model_utils.infer_model_shape_by_ops(
        model, extra_inputs=seed_blobs, get_dtype=True
    )

    # Only shape and dtype matter from here on, so every blob is replaced by
    # a zero-filled array carrying the inferred shape and dtype.
    inferred_blobs = {
        name: np.zeros(dims, dtype=blobs_dtypes[name])
        for name, dims in blobs_dims.items()
    }

    proto = model if isinstance(model, pb2.NetDef) else model.Proto()

    ops, input_shapes, input_dtypes = [], [], []
    for op in proto.op:
        op_type, op_inputs = op.type, op.input

        if op_type in ("Conv", "FC"):
            # Presumably data, weight and an optional bias -- confirm against
            # the operator schema.
            assert len(op_inputs) in (2, 3), (
                "{} operator expects 2 or 3 inputs, got {}".format(
                    op_type, len(op_inputs)
                )
            )

        # Look each input up once; the entries are already ndarrays, so the
        # shape can be read directly without copying the array.
        op_blobs = [inferred_blobs[str(blob_name)] for blob_name in op_inputs]

        input_shapes.append([blob.shape for blob in op_blobs])
        input_dtypes.append([encode_dtype(str(blob.dtype)) for blob in op_blobs])
        ops.append(op)

    return ops, input_shapes, input_dtypes