def clear_save_assign()

in easy_rec/python/utils/meta_graph_editor.py [0:0]


  def clear_save_assign(self):
    logging.info(
        'update save/Assign, drop tensor_shapes and _output_shapes related to embedding_weights'
    )
    # edit save/Assign
    drop_save_assigns = []
    all_kv_drop = []
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
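      # drop save/Assign nodes that restore embedding_weights, and record their
      # control-dependency names ('^' + name) so they can also be removed from
      # the save/restore_shard (or save/restore_all) inputs below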
      if node.op == 'Assign' and 'save/Assign' in node.name and \
         'embedding_weights' in node.input[0]:
        drop_save_assigns.append('^' + node.name)
        self._all_graph_node_flags[tid] = False
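      # drop the concat of the partitioned embedding_weights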
      elif 'embedding_weights/ConcatPartitions/concat' in node.name:
        self._all_graph_node_flags[tid] = False
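      # drop the Identity node that aliases embedding_weights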
      elif node.name.endswith('/embedding_weights') and node.op == 'Identity':
        self._all_graph_node_flags[tid] = False
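      # drop save/KvResourceImportV2 restore ops for kv embedding variables and
      # record them for removal from the restore inputs as well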
      elif 'save/KvResourceImportV2' in node.name and node.op == 'KvResourceImportV2':
        drop_save_assigns.append('^' + node.name)
        self._all_graph_node_flags[tid] = False
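      # drop all other nodes related to KvResourceImportV2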
      elif 'KvResourceImportV2' in node.name:
        self._all_graph_node_flags[tid] = False
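      # drop save/Const nodes whose colocation (_class) attr refers to
      # embedding_weights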
      elif 'save/Const' in node.name and node.op == 'Const':
        if '_class' in node.attr and len(node.attr['_class'].list.s) > 0:
          const_name = node.attr['_class'].list.s[0]
          if not isinstance(const_name, str):
            const_name = const_name.decode('utf-8')
          if 'embedding_weights' in const_name:
            self._all_graph_node_flags[tid] = False
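      # drop ReadKvVariableOp nodes, keeping their names in all_kv_drop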
      elif 'ReadKvVariableOp' in node.name and node.op == 'ReadKvVariableOp':
        all_kv_drop.append(node.name)
        self._all_graph_node_flags[tid] = False
      elif node.op == 'Assign' and 'save/Assign' in node.name:
        # update node(save/Assign_[0-N])'s input[1] to the position of
        #     node.input[0] in save/RestoreV2/tensor_names, because the
        #     outputs of save/RestoreV2 are connected to save/Assign by index
        tmp_id = [
            self.bytes2str(x)
            for x in self._restore_tensor_node.attr['value'].tensor.string_val
        ].index(node.input[0])
        if tmp_id != 0:
          tmp_input2 = 'save/RestoreV2:%d' % tmp_id
        else:
          tmp_input2 = 'save/RestoreV2'
        if tmp_input2 != node.input[1]:
          if self._verbose:
            logging.info("update save/Assign[%s]'s input from %s to %s" %
                         (node.name, node.input[1], tmp_input2))
          node.input[1] = tmp_input2

    # save/restore_all needs save/restore_shard as input
    # save/restore_shard needs save/Assign_[0-N] as input
    # save/Assign_[0-N] needs save/RestoreV2 as input
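    # e.g. save/restore_all.input   = ['^save/restore_shard', ...]
    #      save/restore_shard.input = ['^save/Assign', '^save/Assign_1', ...]
    #      save/Assign_1.input      = [<variable>, 'save/RestoreV2:1']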
    if self._restore_shard_node:
      for tmp_input in drop_save_assigns:
        self._restore_shard_node.input.remove(tmp_input)
        if self._verbose:
          logging.info('drop restore_shard input: %s' % tmp_input)
    elif len(self._restore_all_node) > 0:
      for tmp_input in drop_save_assigns:
        for tmp_node in self._restore_all_node:
          if tmp_input in tmp_node.input:
            tmp_node.input.remove(tmp_input)
            if self._verbose:
              logging.info('drop %s input: %s' % (tmp_node.name, tmp_input))
            # break once the input has been removed, independent of the
            # verbose flag
            break
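
The loop above never deletes nodes from the GraphDef directly; it only flips the matching entry of self._all_graph_node_flags to False so that a later pass can rebuild the graph from the surviving nodes. A minimal sketch of that flag-and-rebuild pattern is shown below; rebuild_graph_def is an illustrative helper name, not a function from meta_graph_editor.py, and the actual rebuild step in the editor may differ.

from tensorflow.core.framework import graph_pb2


def rebuild_graph_def(all_graph_nodes, all_graph_node_flags):
  """Build a pruned GraphDef keeping only the nodes whose flag is still True."""
  pruned = graph_pb2.GraphDef()
  for node, keep in zip(all_graph_nodes, all_graph_node_flags):
    if keep:
      pruned.node.add().CopyFrom(node)
  return pruned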