in runtime_lut/code/model_utils.py [0:0]
def create_fill_op(name, blob, device_option=None):
    """Create an operator that stores the tensor ``blob`` into blob ``name``.

    Supported inputs:
      * numpy arrays of dtype float32 / int32 / int64 -> ``GivenTensorFill`` /
        ``GivenTensorIntFill`` / ``GivenTensorInt64Fill`` with ``values=blob``
        and ``shape=blob.shape``;
      * numpy arrays of dtype uint8 -> ``GivenTensorStringFill`` whose single
        string element is the raw bytes of the array (``shape=[1]``);
      * ``workspace.Int8Tensor`` whose ``data`` is uint8 or int32 ->
        ``Int8GivenTensorFill`` / ``Int8GivenIntTensorFill`` with ``values``,
        ``shape`` and the tensor's ``scale`` / ``zero_point`` passed as
        ``Y_scale`` / ``Y_zero_point``.

    Args:
        name: name of the output blob produced by the operator.
        blob: the value to store (see supported inputs above).
        device_option: when not None, forwarded to ``core.CreateOperator``
            as the ``device_option`` argument.

    Returns:
        The operator returned by ``core.CreateOperator``.

    Raises:
        KeyError: if no fill operator exists for the blob's dtype / type.
        AssertionError: if an Int8Tensor's data is neither uint8 nor int32.
    """
    numpy_fill_ops = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }
    int8_fill_ops = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    try:
        blob_type = blob.dtype
    except AttributeError:
        # Objects without a ``dtype`` attribute (e.g. workspace.Int8Tensor)
        # are dispatched on their Python type instead.
        blob_type = type(blob)
    except Exception as e:
        print("Error when geting blob type {}: {}\n{}".format(name, blob, e))
        raise

    # Check the Int8Tensor class first (by identity) so the dtype comparisons
    # below only ever compare numpy dtypes with numpy dtypes.
    if blob_type is workspace.Int8Tensor:
        data = blob.data
        data_type = data.dtype
        assert data_type in [np.dtype("uint8"), np.dtype("int32")]
        op_type = int8_fill_ops[data_type]
        # uint8 payloads are passed as a raw byte string; int32 as the array.
        values = data.tobytes() if data_type == np.dtype("uint8") else data
        args_dict = {
            "values": values,
            "shape": data.shape,
            "Y_scale": blob.scale,
            "Y_zero_point": blob.zero_point,
        }
    else:
        op_type = numpy_fill_ops.get(blob_type)
        if op_type is None:
            raise KeyError(
                "Unsupported blob type for fill op '{}': {}".format(name, blob_type)
            )
        if blob_type == np.dtype("uint8"):
            # Pack the whole array into one string element. The former
            # ``str(blob.data)`` is a Python 2 idiom: on Python 3 ``.data`` is a
            # memoryview and ``str()`` yields "<memory at 0x...>", not the bytes.
            args_dict = {"values": [blob.tobytes()], "shape": [1]}
        else:
            args_dict = {"values": blob, "shape": blob.shape}

    if device_option is not None:
        args_dict["device_option"] = device_option
    return core.CreateOperator(op_type, [], [name], **args_dict)