# caffe2/python/core.py (excerpt)
def _recover_record_by_prefix(names, prefix=''):
"""
    Tries to recover a record by taking the subset of blob names that start
    with the given prefix and interpreting the remainders as schema column
    names
"""
from caffe2.python import schema
column_names = [name[len(prefix):] for name in names
if name.startswith(prefix)]
if not column_names:
return None
return schema.from_column_list(
column_names,
col_blobs=[_get_blob_ref(prefix + name) for name in column_names])
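# Usage sketch (illustrative, not part of the library; blob names are
# hypothetical): blobs "train/feat" and "train/label" recovered with the
# shared prefix yield a record with columns "feat" and "label":
#
#   record = _recover_record_by_prefix(
#       ['train/feat', 'train/label', 'eval/feat'], prefix='train/')
#   # record.field_names() == ['feat', 'label']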
class Net(object):
_net_names_used = set()
operator_registry_ = {}
@staticmethod
def current_prefix():
from caffe2.python.net_builder import NetBuilder
builder = NetBuilder.current(required=False)
return builder.name if builder else ''
@staticmethod
def _get_next_net_name(basename):
name = basename = '/'.join(
x for x in [Net.current_prefix(), basename] if x
)
next_idx = 1
while name in Net._net_names_used:
name = basename + '_' + str(next_idx)
next_idx += 1
        Net._net_names_used.add(name)
return name
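    # Naming sketch (illustrative): repeated requests for the same basename
    # are de-duplicated with a numeric suffix. Assuming no prior use and no
    # active NetBuilder prefix:
    #
    #   Net._get_next_net_name('train')  # -> 'train'
    #   Net._get_next_net_name('train')  # -> 'train_1'
    #   Net._get_next_net_name('train')  # -> 'train_2'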
def __init__(self, name_or_proto, inplace=False):
"""
Create a Net.
Args:
name_or_proto: If a NetDef is provided, clone it (or take ownership,
depending on the value of `inplace`). Otherwise,
create an empty net with the given name.
inplace: If a NetDef is provided, take ownership when `inplace` is True;
otherwise, clone it.
"""
self._input_record = None
self._output_record = None
# Register blobs so that it's guaranteed that different calls to
# NextBlob/NextScopedBlob always return blobs with different names
self._registered_blob_names = set()
self._recreate_lookup_tables = False
self._op_outputs = set()
self._external_input_map = set()
self._attr_dict = defaultdict(list)
if type(name_or_proto) is caffe2_pb2.NetDef:
proto = name_or_proto
            # We are initializing the network from an existing NetDef.
if inplace:
self._net = proto
else:
self._net = caffe2_pb2.NetDef()
self._net.CopyFrom(proto)
existing_outputs = [list(op.output) for op in self._net.op]
self._external_input_map.update(list(self._net.external_input))
# Set the next name index properly.
existing_names = set()
for op in self._net.op:
existing_names.update(list(op.input))
for output in existing_outputs:
existing_names.update(output)
for outs in existing_outputs:
self._op_outputs.update(outs)
prefix_len = len(self._net.name + '_blob_')
autogen_indices = []
for s in existing_names:
if s.startswith(self._net.name + '_blob_'):
try:
                        autogen_indices.append(int(s[prefix_len:]))
except ValueError:
pass
            if autogen_indices:
self._next_name_index = max(autogen_indices) + 1
else:
self._next_name_index = 0
name = self._net.name
else:
name = name_or_proto
self._net = caffe2_pb2.NetDef()
self._next_name_index = 0
# make sure that this net name hasn't been used before
self._net.name = Net._get_next_net_name(name)
# a map between prefix and ID for fast generation of blob names
self._next_blob_name_ids = {}
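    # Construction sketch (illustrative; names are hypothetical): a Net can
    # be created empty by name, or from an existing NetDef:
    #
    #   net_a = Net('my_net')                     # empty net named 'my_net'
    #   net_b = Net(net_a.Proto())                # deep-copies the NetDef
    #   net_c = Net(net_a.Proto(), inplace=True)  # takes ownership, no copy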
def AppendNet(self, net, device_option=None):
assert isinstance(net, Net)
for i in net.Proto().external_input:
if (
i not in self.Proto().external_input and
i not in self._op_outputs
):
self.Proto().external_input.append(i)
self.Proto().external_output.extend(
[
o for o in net.Proto().external_output
if o not in self.Proto().external_output
]
)
ops = net.Proto().op
if device_option is not None:
ops = [copy.deepcopy(op) for op in ops]
for op in ops:
op.device_option.CopyFrom(device_option)
for op in ops:
if op.type == "RecurrentNetwork":
for arg in op.arg:
if arg.name.endswith('step_net'):
for step_op in arg.n.op:
step_op.device_option.CopyFrom(device_option)
self._ExtendOps(ops)
return self
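    # Append sketch (illustrative; net names are hypothetical): the ops and
    # external inputs/outputs of `tail` are merged into `main`, skipping
    # inputs that `main` already produces; the call returns self, so it can
    # be chained:
    #
    #   main = Net('main')
    #   tail = Net('tail')
    #   main.AppendNet(tail)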
def LogInfo(self, *msg_or_blobs):
for msg_or_blob in msg_or_blobs:
if not isinstance(msg_or_blob, BlobReference):
blob = self.GivenTensorStringFill(
[], self.NextName('log'),
shape=[], values=[msg_or_blob])
else:
blob = msg_or_blob
self.Print(blob, [])
def add_attribute(self, name, obj):
"""
Add `obj` to the list of attributes in this net under the given `name`.
Attributes are user-defined objects and have no pre-defined semantics.
"""
self._attr_dict[name].append(obj)
def get_attributes(self, name):
"""
Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute`.
"""
return self._attr_dict.get(name, [])
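    # Attribute sketch (illustrative; `my_optimizer` is a hypothetical
    # object): attributes form a per-net key -> list-of-objects store with
    # no built-in semantics:
    #
    #   net.add_attribute('optimizer', my_optimizer)
    #   net.get_attributes('optimizer')  # -> [my_optimizer]
    #   net.get_attributes('missing')    # -> []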
def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
"""
Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op gets rand_seed=`seed + i`.
        If seed_on_op_def is set, each op gets
        rand_seed=`hash(str(op) + str(seed)) % np.iinfo(np.uint32).max`.
        sequence_seed and seed_on_op_def cannot both be set to True.
        """
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot both be set to True.')
for i, op in enumerate(self.Proto().op):
if sequence_seed:
curr_seed = seed + i
elif seed_on_op_def:
curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
else:
curr_seed = seed
op.device_option.random_seed = curr_seed
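    # Seeding sketch (illustrative): with the default sequence_seed=True,
    # ops receive consecutive seeds derived from `seed`:
    #
    #   net.set_rand_seed(seed=42)
    #   # op 0 gets device_option.random_seed == 42, op 1 gets 43, ...
    #   net.set_rand_seed(seed=42, sequence_seed=False, seed_on_op_def=True)
    #   # each op gets a seed hashed from its serialized def and the seed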
def Name(self):
return self._net.name
def __str__(self):
return self.Name()
def Const(self, array, blob_out=None, dtype=None):
if isinstance(array, bool):
return self.ConstantFill(
[],
blob_out or 1,
dtype=DataType.BOOL,
value=array)
if dtype is None:
array = np.array(array)
else:
array = np.array(array, dtype=dtype)
def do_set(operator):
return operator(
[],
blob_out or 1,
shape=array.shape,
values=array.flatten().tolist())
if array.dtype == np.int32:
return do_set(self.GivenTensorIntFill)
elif array.dtype == np.int64:
return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == np.str_:
            return do_set(self.GivenTensorStringFill)
        elif array.dtype == np.bool_:
            return do_set(self.GivenTensorBoolFill)
else:
return do_set(self.GivenTensorFill)
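    # Const sketch (illustrative; blob names are hypothetical): the fill
    # operator is chosen from the numpy dtype of the converted input:
    #
    #   net.Const(True, 'flag')                    # ConstantFill, BOOL
    #   net.Const([1, 2], 'ints', dtype=np.int64)  # GivenTensorInt64Fill
    #   net.Const([1.5, 2.5], 'floats')            # GivenTensorFill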
def BlobIsDefined(self, blob):
"""
Returns true if the given BlobReference is produced as output of
an operator in this net, or if it is provided as an external input.
"""
if self._recreate_lookup_tables:
self._RecreateLookupTables()
name = str(blob)
return (name in self._op_outputs) or (name in self._external_input_map)
def UsesBlob(self, blob):
"""
        Returns true iff the given BlobReference is used by any operator
        in this net, or if it is one of the external inputs of the net.
"""
blob_name = str(blob)
for op in self._net.op:
for input in op.input:
if input == blob_name:
return True
return blob_name in self._external_input_map
def UsedBlobNames(self):
"""
Returns a set of blob names used in the net
"""
blob_names = set()
for op in self._net.op:
blob_names |= set(op.input)
blob_names |= set(op.output)
if self._net.external_input:
blob_names |= set(self._net.external_input)
if self._net.external_output:
blob_names |= set(self._net.external_output)
return blob_names
def GetBlobRef(self, blob_name):
"""
Given the name of a blob produced by this net, return a BlobReference
to it. If the blob is not produced by any op in this net,
raises KeyError.
"""
blob_name = str(blob_name)
if not self.BlobIsDefined(blob_name):
raise KeyError('Net does not define blob %s' % blob_name)
return BlobReference(blob_name, self)
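    # Lookup sketch (illustrative; blob names are hypothetical):
    # BlobIsDefined answers membership, GetBlobRef converts a known name
    # into a BlobReference bound to this net:
    #
    #   x = net.AddExternalInput('x')
    #   y = net.Relu([x], 'y')
    #   net.BlobIsDefined('y')   # -> True
    #   net.GetBlobRef('y')      # -> BlobReference('y')
    #   net.GetBlobRef('nope')   # raises KeyError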
def Clone(
self,
name,
blob_remap=None,
op_id_mask=None,
remap_funcs=None,
keep_schema=True,
update_external_list=False,
):
"""
Clone this net.
Args:
name: name of the cloned net
            blob_remap: optional map from original blob names to new names;
                every occurrence of a key is replaced by its value
            op_id_mask: optional list of operator indices to include in
                the cloned net. If not provided, all ops are included.
"""
orig_remap_funcs = {} if remap_funcs is None else remap_funcs
        # By default we put RecurrentNetworkOp and RecurrentNetworkGradientOp
        # into remap_funcs, since these two operators also embed blob names
        # and protos in their arguments.
remap_funcs = DEFAULT_REMAP_FUNCS.copy()
remap_funcs.update(orig_remap_funcs)
proto = self._net
new_proto = caffe2_pb2.NetDef()
new_proto.CopyFrom(proto)
new_proto.name = name
if blob_remap is None:
blob_remap = {}
if op_id_mask is None:
op_id_mask = list(range(0, len(proto.op)))
def get_remapped_str(blob):
blob_str = str(blob)
return str(blob_remap.get(blob_str, blob_str))
def remap_list(proto_list):
new_list = [get_remapped_str(b) for b in proto_list]
del proto_list[:]
proto_list.extend(new_list)
def remap_op(op):
new_op = caffe2_pb2.OperatorDef()
new_op.CopyFrom(op)
remap_list(new_op.input)
remap_list(new_op.output)
if new_op.type in remap_funcs:
remap_funcs[new_op.type](
new_op,
(name + '/') if name else '',
blob_remap,
)
return new_op
del new_proto.op[:]
new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
remap_list(new_proto.external_input)
remap_list(new_proto.external_output)
new_net = Net(new_proto)
if keep_schema:
from caffe2.python import schema
if self._input_record:
new_net._input_record = schema.from_blob_list(
self._input_record,
[
BlobReference(get_remapped_str(blob), net=new_net)
for blob in self._input_record.field_blobs()
],
)
if self._output_record:
new_net._output_record = schema.from_blob_list(
self._output_record,
[
BlobReference(get_remapped_str(blob), net=new_net)
for blob in self._output_record.field_blobs()
],
)
new_net._attr_dict.update(self._attr_dict)
if update_external_list:
# external input list
existing_outputs = set()
used_outputs = set()
del new_net.Proto().external_input[:]
del new_net.Proto().external_output[:]
for op in new_net.Proto().op:
for ib in op.input:
if ib not in existing_outputs:
new_net.Proto().external_input.extend([ib])
else:
used_outputs.add(ib)
for ob in op.output:
existing_outputs.add(ob)
# external outputs
for ob in existing_outputs:
if ob not in used_outputs:
new_net.Proto().external_output.extend([ob])
return new_net
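    # Clone sketch (illustrative; names are hypothetical): blob_remap
    # renames blobs throughout the copy, op_id_mask selects ops by index:
    #
    #   net2 = net.Clone('copy', blob_remap={'x': 'x2'})
    #   # every occurrence of blob 'x' in net2 is now 'x2'
    #   head = net.Clone('head', op_id_mask=[0, 1])  # first two ops only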
def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
"""
Clone this net, including only ops that are necessary in order to
compute `outputs` given `inputs`. Return references to the cloned
outputs. Internal blobs (blobs that are produced and consumed inside
the net but not used as outputs) will be remapped to avoid name
conflict.
Args:
name: the name of the cloned net
inputs: map where the keys correspond to BlobReferences in the
original net, and the values correspond to external inputs
in the partially cloned net. If `inputs` is a list, don't
remap input names.
outputs: outputs to be produced by the cloned net.
Returns:
Tuple (new_net, new_outputs)
new_net: a new Net object.
new_outputs: list of BlobReferences corresponding to the
outputs produced by new_net.
"""
input_is_pair_list = isinstance(inputs, list) and all(
isinstance(i, tuple) and len(i) == 2 for i in inputs)
inputs = (
inputs if isinstance(inputs, (dict, OrderedDict)) else
OrderedDict(inputs) if input_is_pair_list else
OrderedDict(zip(inputs, inputs)))
for output in outputs:
assert self.BlobIsDefined(output), "{} is not defined".format(output)
input_names = {str(k): str(v) for k, v in viewitems(inputs)}
output_names = [str(o) for o in outputs]
proto = self._net
blob_versions = {str(i): 0 for i in inputs}
ssa, blob_versions = get_ssa(proto, blob_versions)
used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
'Cannot partially clone net: some of the ops required would ' +
'generate the given input.')
sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
undef_blobs = get_undefined_blobs(sub_ssa) - set(viewkeys(input_names))
prefix = (name + '/') if name else ''
def remap(blob_name):
if blob_name in input_names:
return input_names[blob_name]
elif blob_name in undef_blobs:
return blob_name
else:
return prefix + blob_name
blob_mapping = {b: remap(b) for b in viewkeys(blob_versions)}
new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
new_in = [
blob_mapping[i] for i in viewkeys(input_names)] + list(undef_blobs)
new_out = [blob_mapping[o] for o in output_names]
del new_net.Proto().external_input[:]
new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(new_in)
del new_net.Proto().external_output[:]
new_net.Proto().external_output.extend(new_out)
return new_net, [new_net.GetBlobRef(o) for o in new_out]
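    # Partial-clone sketch (illustrative; names are hypothetical): only ops
    # on the path from `inputs` to `outputs` are kept, and internal blobs
    # get a 'name/' prefix to avoid collisions:
    #
    #   sub, (new_out,) = net.ClonePartial(
    #       'sub', inputs={'x': 'x_in'}, outputs=['y'])
    #   # sub computes new_out ('sub/y') from external input 'x_in'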
def Proto(self):
self._InvalidateLookupTables()
return self._net
def insert_op_at_idx(self, op, op_idx):
r""" inserting operator at index. Will update external blob list.
"""
assert op_idx >= 0
temp_ops = self.Proto().op[op_idx:]
del self.Proto().op[op_idx:]
self.Proto().op.extend([op])
self.Proto().op.extend(temp_ops)
self.external_outputs.extend(op.output)
self.external_inputs.extend(op.input)
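    # Insertion sketch (illustrative; op and blob names are hypothetical):
    # an OperatorDef built with core.CreateOperator can be spliced in at
    # any position:
    #
    #   op = core.CreateOperator('Relu', ['x'], ['y'])
    #   net.insert_op_at_idx(op, 0)
    #   # 'y' and 'x' are appended to the external output/input lists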
def reroute_tensor(self, tensor, new_producer, can_modify=None):
r""" reroute tensor to new_producer. And feed new tensor to consumers
and interseciton with can_modify if provided.
Inputs:
tensor: str or blob_reference the tensor to reroute
new_producer: an op takes in tensor gives new_tesnor
can_modify: a list/set of operators that consumes tensor and can be
modified
Returns:
reroute_cnt: how many consumer op has been changed
Note: assume no inplace blob in net