def _recover_record_by_prefix()

in caffe2/python/core.py [0:0]


def _recover_record_by_prefix(names, prefix=''):
    """
    Rebuild a schema record from the blob names that start with `prefix`.

    Every name in `names` carrying `prefix` has the prefix stripped, and the
    remainder is used as a schema column name. The blob named
    `prefix + column` is attached as that column's blob.

    Returns:
        The record built by `schema.from_column_list`, or None when no name
        in `names` starts with `prefix`.
    """
    from caffe2.python import schema
    prefix_len = len(prefix)
    columns = []
    for blob_name in names:
        if blob_name.startswith(prefix):
            columns.append(blob_name[prefix_len:])
    if len(columns) == 0:
        return None
    blobs = [_get_blob_ref(prefix + column) for column in columns]
    return schema.from_column_list(columns, col_blobs=blobs)


class Net(object):
    # Every net name handed out so far, shared by all Net instances;
    # _get_next_net_name checks and extends it so names stay unique.
    _net_names_used = set()
    # Class-level dict, not referenced in this part of the file.
    # NOTE(review): presumably a registry of operator types — confirm.
    operator_registry_ = {}

    @staticmethod
    def current_prefix():
        """Return the name of the active NetBuilder, or '' if there is none."""
        from caffe2.python.net_builder import NetBuilder
        active = NetBuilder.current(required=False)
        if not active:
            return ''
        return active.name

    @staticmethod
    def _get_next_net_name(basename):
        """
        Reserve and return a net name that has not been handed out before.

        The candidate is `basename`, joined with '/' to the current NetBuilder
        prefix when that prefix is non-empty. While the candidate is already
        taken, the suffixes '_1', '_2', ... are tried in turn. The chosen name
        is recorded in `Net._net_names_used`.
        """
        parts = [part for part in (Net.current_prefix(), basename) if part]
        qualified = '/'.join(parts)
        candidate = qualified
        suffix = 1
        while candidate in Net._net_names_used:
            candidate = qualified + '_' + str(suffix)
            suffix += 1
        Net._net_names_used.add(candidate)
        return candidate

    def __init__(self, name_or_proto, inplace=False):
        """
        Create a Net.
        Args:
            name_or_proto:  If a NetDef is provided, clone it (or take ownership,
                            depending on the value of `inplace`). Otherwise,
                            create an empty net with the given name.
            inplace: If a NetDef is provided, take ownership when `inplace` is True;
                     otherwise, clone it.

        The final net name is produced by `Net._get_next_net_name`. It may
        therefore carry the current NetBuilder prefix and/or a numeric suffix
        that keeps it unique.
        """
        self._input_record = None
        self._output_record = None
        # Register blobs so that it's guaranteed that different calls to
        # NextBlob/NextScopedBlob always return blobs with different names
        self._registered_blob_names = set()
        self._recreate_lookup_tables = False
        # Lookup tables: names produced by some op / listed as external input.
        self._op_outputs = set()
        self._external_input_map = set()
        self._attr_dict = defaultdict(list)
        if type(name_or_proto) is caffe2_pb2.NetDef:
            proto = name_or_proto
            # We are initializing a network by a NetDef. In this case, we will
            # initialize our network with the given netdef.
            if inplace:
                self._net = proto
            else:
                self._net = caffe2_pb2.NetDef()
                self._net.CopyFrom(proto)

            self._external_input_map.update(self._net.external_input)

            # Collect every blob name read or written by an op; outputs also
            # populate the op-output lookup table.
            existing_names = set()
            for op in self._net.op:
                existing_names.update(op.input)
                existing_names.update(op.output)
                self._op_outputs.update(op.output)

            # Set the next name index properly: continue after the largest
            # integer suffix among names of the form "<net name>_blob_<N>".
            autogen_prefix = self._net.name + '_blob_'
            prefix_len = len(autogen_prefix)
            autogen_indices = []
            for s in existing_names:
                if s.startswith(autogen_prefix):
                    try:
                        # Parse the whole suffix, not only its first character,
                        # so that e.g. "<net>_blob_12" yields 12 rather than 1.
                        autogen_indices.append(int(s[prefix_len:]))
                    except ValueError:
                        # Suffix is not an integer: not an auto-generated name.
                        pass
            if autogen_indices:
                self._next_name_index = max(autogen_indices) + 1
            else:
                self._next_name_index = 0
            name = self._net.name
        else:
            name = name_or_proto
            self._net = caffe2_pb2.NetDef()
            self._next_name_index = 0

        # make sure that this net name hasn't been used before
        self._net.name = Net._get_next_net_name(name)

        # a map between prefix and ID for fast generation of blob names
        self._next_blob_name_ids = {}


    def AppendNet(self, net, device_option=None):
        """
        Append the ops of `net` to this net and merge its external blob lists.

        An external input of `net` is added to this net's external inputs
        unless it is already listed there or is already in this net's
        op-output table. External outputs of `net` that are not yet listed
        here are added.

        Args:
            net: the Net whose ops are appended (must be a Net instance).
            device_option: optional DeviceOption. When given, deep copies of
                the ops are appended instead, each with this device option.
                Ops nested in the `*step_net` arguments of RecurrentNetwork
                ops receive it as well.

        Returns:
            self, so calls can be chained.
        """
        assert isinstance(net, Net)
        source = net.Proto()
        target = self.Proto()
        for blob_name in source.external_input:
            if blob_name in target.external_input:
                continue
            if blob_name in self._op_outputs:
                continue
            target.external_input.append(blob_name)

        # Computed before extending, so duplicates within `net`'s own list are
        # filtered only against this net's pre-existing outputs.
        fresh_outputs = [
            blob_name for blob_name in source.external_output
            if blob_name not in target.external_output
        ]
        target.external_output.extend(fresh_outputs)

        ops = source.op
        if device_option is not None:
            ops = [copy.deepcopy(op) for op in ops]
            for op in ops:
                op.device_option.CopyFrom(device_option)
                if op.type != "RecurrentNetwork":
                    continue
                for arg in op.arg:
                    if not arg.name.endswith('step_net'):
                        continue
                    for step_op in arg.n.op:
                        step_op.device_option.CopyFrom(device_option)

        self._ExtendOps(ops)
        return self

    def LogInfo(self, *msg_or_blobs):
        """
        Add one Print op per argument.

        A BlobReference argument is printed as is. Any other value is first
        stored via GivenTensorStringFill (shape [], values [value]) into a
        blob named by `NextName('log')`, and that blob is printed.
        """
        for item in msg_or_blobs:
            if isinstance(item, BlobReference):
                target = item
            else:
                target = self.GivenTensorStringFill(
                    [], self.NextName('log'), shape=[], values=[item])
            self.Print(target, [])

    def add_attribute(self, name, obj):
        """
        Add `obj` to the list of attributes in this net under the given `name`.
        Attributes are user-defined objects and have no pre-defined semantics.
        """
        self._attr_dict[name].append(obj)

    def get_attributes(self, name):
        """
        Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute'.
        """
        return self._attr_dict.get(name, [])

    def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
        """
        Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`.
        If seed_on_op_def is set, the op rand_seed is derived from a stable
        digest of str(op) and `seed`, so the same op definition gets the same
        seed in every process.
        Otherwise every op gets rand_seed=`seed`.
        sequence_seed and seed_on_op_def cannot be both set to True.
        """
        import hashlib
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot be both set to True.')
        seed_modulus = np.iinfo(np.uint32).max
        for i, op in enumerate(self.Proto().op):
            if sequence_seed:
                curr_seed = seed + i
            elif seed_on_op_def:
                # Built-in hash() of a str is salted per process
                # (PYTHONHASHSEED), which would give different seeds on every
                # run; a fixed digest keeps the seed reproducible.
                digest = hashlib.sha256(
                    (str(op) + str(seed)).encode('utf-8')).hexdigest()
                curr_seed = int(digest, 16) % seed_modulus
            else:
                curr_seed = seed
            op.device_option.random_seed = curr_seed

    def Name(self):
        return self._net.name

    def __str__(self):
        return self.Name()

    def Const(self, array, blob_out=None, dtype=None):
        if isinstance(array, bool):
            return self.ConstantFill(
                [],
                blob_out or 1,
                dtype=DataType.BOOL,
                value=array)

        if dtype is None:
            array = np.array(array)
        else:
            array = np.array(array, dtype=dtype)

        def do_set(operator):
            return operator(
                [],
                blob_out or 1,
                shape=array.shape,
                values=array.flatten().tolist())

        if array.dtype == np.int32:
            return do_set(self.GivenTensorIntFill)
        elif array.dtype == np.int64:
            return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == np.str:
            return do_set(self.GivenTensorStringFill)
        elif array.dtype == np.bool:
            return do_set(self.GivenTensorBoolFill)
        else:
            return do_set(self.GivenTensorFill)

    def BlobIsDefined(self, blob):
        """
        Returns true if the given BlobReference is produced as output of
        an operator in this net, or if it is provided as an external input.
        """
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()
        name = str(blob)
        return (name in self._op_outputs) or (name in self._external_input_map)

    def UsesBlob(self, blob):
        """
        Returns true iff the given BlobReference is used by any operator
        or this net, or if it is one of the external inputs of the net.
        """
        blob_name = str(blob)
        for op in self._net.op:
            for input in op.input:
                if input == blob_name:
                    return True
        return blob_name in self._external_input_map

    def UsedBlobNames(self):
        """
        Returns a set of blob names used in the net
        """
        blob_names = set()
        for op in self._net.op:
            blob_names |= set(op.input)
            blob_names |= set(op.output)
        if self._net.external_input:
            blob_names |= set(self._net.external_input)
        if self._net.external_output:
            blob_names |= set(self._net.external_output)
        return blob_names

    def GetBlobRef(self, blob_name):
        """
        Given the name of a blob produced by this net, return a BlobReference
        to it. If the blob is not produced by any op in this net,
        raises KeyError.
        """
        blob_name = str(blob_name)
        if not self.BlobIsDefined(blob_name):
            raise KeyError('Net does not define blob %s' % blob_name)
        return BlobReference(blob_name, self)

    def Clone(
        self,
        name,
        blob_remap=None,
        op_id_mask=None,
        remap_funcs=None,
        keep_schema=True,
        update_external_list=False,
    ):
        """
        Clone this net.
        Args:
            name:        name of the cloned net
            blob_remap:  optional dict mapping original blob names to the
                         names to use in the clone; unmapped names are kept.
            op_id_mask:  optional list of operator indices to include in
                         the cloned net, in this order. If not provided, all
                         ops are included.
            remap_funcs: optional dict mapping an op type to a function called
                         as `f(op, prefix, blob_remap)` on every cloned op of
                         that type, where `prefix` is `name + '/'` (or '' when
                         `name` is empty). Entries override DEFAULT_REMAP_FUNCS.
            keep_schema: if True, copy the input/output records, with their
                         blobs renamed through `blob_remap`.
            update_external_list: if True, recompute the external input and
                         output lists from the cloned ops instead of keeping
                         the (remapped) original lists.
        Returns:
            The new Net. Its name goes through Net's name uniquification, so
            it may differ from `name`.
        """
        orig_remap_funcs = {} if remap_funcs is None else remap_funcs
        # by default we want to put RecurrentNetworkOp and
        # RecurrentNetworkGradientOp into remap_funcs, as these two operators
        # also take blobs and proto into the arguments.
        remap_funcs = DEFAULT_REMAP_FUNCS.copy()
        remap_funcs.update(orig_remap_funcs)
        proto = self._net
        new_proto = caffe2_pb2.NetDef()
        new_proto.CopyFrom(proto)
        new_proto.name = name

        if blob_remap is None:
            blob_remap = {}
        if op_id_mask is None:
            op_id_mask = list(range(0, len(proto.op)))

        def get_remapped_str(blob):
            blob_str = str(blob)
            return str(blob_remap.get(blob_str, blob_str))

        def remap_list(proto_list):
            new_list = [get_remapped_str(b) for b in proto_list]
            del proto_list[:]
            proto_list.extend(new_list)

        def remap_op(op):
            new_op = caffe2_pb2.OperatorDef()
            new_op.CopyFrom(op)
            remap_list(new_op.input)
            remap_list(new_op.output)
            if new_op.type in remap_funcs:
                remap_funcs[new_op.type](
                    new_op,
                    (name + '/') if name else '',
                    blob_remap,
                )
            return new_op

        del new_proto.op[:]
        new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
        remap_list(new_proto.external_input)
        remap_list(new_proto.external_output)
        new_net = Net(new_proto)

        if keep_schema:
            from caffe2.python import schema
            if self._input_record:
                new_net._input_record = schema.from_blob_list(
                    self._input_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._input_record.field_blobs()
                    ],
                )
            if self._output_record:
                new_net._output_record = schema.from_blob_list(
                    self._output_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._output_record.field_blobs()
                    ],
                )

        new_net._attr_dict.update(self._attr_dict)
        if update_external_list:
            # A blob is an external input if an op reads it before any cloned
            # op has produced it; it is an external output if it is produced
            # but not read by any later op.
            cloned = new_net.Proto()
            existing_outputs = set()
            produced_in_order = []
            used_outputs = set()
            seen_inputs = set()
            del cloned.external_input[:]
            del cloned.external_output[:]
            for op in cloned.op:
                for ib in op.input:
                    if ib in existing_outputs:
                        used_outputs.add(ib)
                    elif ib not in seen_inputs:
                        # List each external input once, even when several
                        # ops read it.
                        seen_inputs.add(ib)
                        cloned.external_input.extend([ib])
                for ob in op.output:
                    if ob not in existing_outputs:
                        existing_outputs.add(ob)
                        produced_in_order.append(ob)
            # external outputs, in order of first production so the result
            # does not depend on set iteration order
            for ob in produced_in_order:
                if ob not in used_outputs:
                    cloned.external_output.extend([ob])
        return new_net

    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone this net, including only ops that are necessary in order to
        compute `outputs` given `inputs`. Return references to the cloned
        outputs. Internal blobs (blobs that are produced and consumed inside
        the net but not used as outputs) will be remapped to avoid name
        conflict.

        Args:
            name:    the name of the cloned net. Blobs that are neither inputs
                     nor undefined in the selected ops are prefixed with
                     `name + '/'` in the clone.
            inputs:  map where the keys correspond to BlobReferences in the
                     original net, and the values correspond to external inputs
                     in the partially cloned net. A list of (original, new)
                     pairs is treated the same way. If `inputs` is a plain
                     list, don't remap input names.
            outputs: outputs to be produced by the cloned net.
            remap_funcs: optional per-op-type remap functions, forwarded to
                     `Clone`.

        Returns:
            Tuple (new_net, new_outputs)
                new_net:       a new Net object.
                new_outputs:   list of BlobReferences corresponding to the
                               outputs produced by new_net.

        Raises:
            AssertionError: if an output is not defined in this net, or if an
                op required for `outputs` would also generate a given input.
        """
        input_is_pair_list = isinstance(inputs, list) and all(
            isinstance(i, tuple) and len(i) == 2 for i in inputs)
        inputs = (
            inputs if isinstance(inputs, (dict, OrderedDict)) else
            OrderedDict(inputs) if input_is_pair_list else
            OrderedDict(zip(inputs, inputs)))
        for output in outputs:
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
        input_names = {str(k): str(v) for k, v in viewitems(inputs)}
        output_names = [str(o) for o in outputs]
        proto = self._net
        blob_versions = {str(i): 0 for i in inputs}
        ssa, blob_versions = get_ssa(proto, blob_versions)
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
        # Set for fast membership tests below. The original `used_op_ids` is
        # still what gets passed to Clone, which includes ops in its order.
        used_op_id_set = set(used_op_ids)
        assert len(used_op_id_set & set(disallowed_op_ids)) == 0, (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_id_set]
        undef_blobs = get_undefined_blobs(sub_ssa) - set(viewkeys(input_names))
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            elif blob_name in undef_blobs:
                return blob_name
            else:
                return prefix + blob_name

        blob_mapping = {b: remap(b) for b in viewkeys(blob_versions)}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        # Undefined blobs keep their own names (see `remap`); sort them so the
        # external input order does not depend on set iteration order.
        new_in = [
            blob_mapping[i] for i in viewkeys(input_names)] + sorted(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        del new_net.Proto().external_input[:]
        new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(new_in)
        del new_net.Proto().external_output[:]
        new_net.Proto().external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]

    def Proto(self):
        """
        Return the underlying NetDef itself (not a copy).

        `_InvalidateLookupTables` is called first, presumably because the
        caller may mutate the returned proto.
        """
        self._InvalidateLookupTables()
        net_def = self._net
        return net_def

    def insert_op_at_idx(self, op, op_idx):
        r"""Insert `op` so that it sits at position `op_idx` in the op list.

        The ops from `op_idx` onward are detached, `op` is appended, and the
        detached ops are appended back after it. Then `op.output` is passed
        to `self.external_outputs.extend` and `op.input` to
        `self.external_inputs.extend`.

        NOTE(review): `external_outputs` / `external_inputs` are not defined
        in this part of the file. If they are properties that build a fresh
        list, these extend calls would not change the proto — confirm.
        """
        assert op_idx >= 0
        op_list = self.Proto().op
        tail = op_list[op_idx:]
        del op_list[op_idx:]
        op_list.extend([op])
        op_list.extend(tail)
        self.external_outputs.extend(op.output)
        self.external_inputs.extend(op.input)

    def reroute_tensor(self, tensor, new_producer, can_modify=None):
        r""" reroute tensor to new_producer. And feed new tensor to consumers
        and interseciton with can_modify if provided.
        Inputs:
            tensor: str or blob_reference the tensor to reroute
            new_producer: an op takes in tensor gives new_tesnor
            can_modify: a list/set of operators that consumes tensor and can be
            modified

        Returns:
            reroute_cnt: how many consumer op has been changed

        Note: assume no inplace blob in net