in runtime_lut/code/api.py [0:0]
def get_ops_from_net(model, blobs, input_dims):
    """Extract all Conv/FC operators and their input shapes/dtypes from a model.

    Blob shapes and dtypes are inferred with
    ``model_utils.infer_model_shape_by_ops``. The inference is fed the
    non-string entries of ``blobs`` plus float32 arrays of ones for every
    entry of ``input_dims``.

    Args:
        model: Either a ``pb2.NetDef`` or an object whose ``Proto()`` method
            returns one.
        blobs: Mapping of blob name -> value. Entries whose value is a ``str``
            are skipped; the remaining ones are passed to shape inference as
            extra inputs.
        input_dims: Mapping of input blob name -> shape. Each one is fed to
            shape inference as ``np.ones(shape, dtype=np.float32)``. It
            overrides any same-named entry from ``blobs``.

    Returns:
        A tuple ``(ops, input_shapes, input_dtypes)`` of parallel lists. There
        is one entry per ``Conv``/``FC`` op, in proto order:
        - the op itself;
        - the list of shapes of its input blobs;
        - the list of their dtypes, as encoded by ``encode_dtype``.

    Raises:
        ValueError: If a ``Conv``/``FC`` op does not have exactly 2 or 3 inputs.
        KeyError: If an op input blob has no inferred shape/dtype.
    """
    # Keep only non-string values, then add placeholder arrays for the
    # declared inputs (these win over same-named entries in `blobs`).
    extra_inputs = {
        name: blobs[name] for name in blobs if not isinstance(blobs[name], str)
    }
    extra_inputs.update(
        {name: np.ones(input_dims[name], dtype=np.float32) for name in input_dims}
    )
    blobs_dims, blobs_dtypes = model_utils.infer_model_shape_by_ops(
        model, extra_inputs=extra_inputs, get_dtype=True
    )
    # Materialize one zero array per inferred blob so shape and dtype can be
    # read back in numpy-normalized form (tuple shape, numpy dtype name).
    known_blobs = {
        name: np.zeros(blobs_dims[name], dtype=blobs_dtypes[name])
        for name in blobs_dims
    }

    proto = model if isinstance(model, pb2.NetDef) else model.Proto()

    ops, input_shapes, input_dtypes = [], [], []
    for op in proto.op:
        if op.type not in ("Conv", "FC"):
            continue
        # NOTE(review): presumably inputs are (X, W) or (X, W, bias);
        # confirm against the Caffe2 op schema.
        if len(op.input) not in (2, 3):
            raise ValueError(
                f"{op.type} op expected 2 or 3 inputs, got {len(op.input)}"
            )
        # Values are already ndarrays; read attributes directly instead of
        # copying them through np.array().
        arrays = [known_blobs[str(blob_name)] for blob_name in op.input]
        input_shapes.append([arr.shape for arr in arrays])
        input_dtypes.append([encode_dtype(str(arr.dtype)) for arr in arrays])
        ops.append(op)
    return ops, input_shapes, input_dtypes