in easy_rec/python/utils/meta_graph_editor.py [0:0]
def clear_save_assign(self):
  """Drop embedding_weights-related save/restore nodes and rewire save/Assign.

  Walks ``self._all_graph_nodes``, skipping nodes whose entry in
  ``self._all_graph_node_flags`` is already False, and:

  * sets the flag to False for: save/Assign nodes whose ``input[0]``
    contains ``embedding_weights``; ``embedding_weights/ConcatPartitions/concat``
    nodes; Identity nodes named ``*/embedding_weights``; KvResourceImportV2
    nodes; ReadKvVariableOp nodes; and ``save/Const`` nodes whose first
    ``_class`` entry mentions ``embedding_weights``;
  * for every other (kept) save/Assign node, rewrites ``input[1]`` to the
    ``save/RestoreV2`` output whose index equals the position of
    ``input[0]`` in the tensor-name strings of ``self._restore_tensor_node``;
  * removes the control inputs (``'^' + name``) of the dropped save/Assign
    and save/KvResourceImportV2 nodes from ``self._restore_shard_node`` when
    it is set, otherwise from the first node in ``self._restore_all_node``
    that lists them.

  Raises:
    ValueError: if a kept save/Assign node's ``input[0]`` is not among the
      restore tensor names.
  """
  logging.info(
      'update save/Assign, drop tensor_shapes, _output_shapes, related to embedding_weights'
  )
  # control-input names ('^' + node name) to detach from the restore nodes
  drop_save_assigns = []
  # tensor name -> position in save/RestoreV2's tensor_names; built lazily
  # so self._restore_tensor_node is only read when a kept save/Assign exists
  restore_name_pos = None
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.op == 'Assign' and 'save/Assign' in node.name and \
        'embedding_weights' in node.input[0]:
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'embedding_weights/ConcatPartitions/concat' in node.name:
      self._all_graph_node_flags[tid] = False
    elif node.name.endswith('/embedding_weights') and node.op == 'Identity':
      self._all_graph_node_flags[tid] = False
    elif 'save/KvResourceImportV2' in node.name and \
        node.op == 'KvResourceImportV2':
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'KvResourceImportV2' in node.name:
      self._all_graph_node_flags[tid] = False
    elif 'save/Const' in node.name and node.op == 'Const':
      if '_class' in node.attr and len(node.attr['_class'].list.s) > 0:
        const_name = node.attr['_class'].list.s[0]
        if not isinstance(const_name, str):
          const_name = const_name.decode('utf-8')
        if 'embedding_weights' in const_name:
          self._all_graph_node_flags[tid] = False
    elif 'ReadKvVariableOp' in node.name and node.op == 'ReadKvVariableOp':
      self._all_graph_node_flags[tid] = False
    elif node.op == 'Assign' and 'save/Assign' in node.name:
      # update node(save/Assign_[0-N])'s input[1] by the position of
      # node.input[0] in save/RestoreV2/tensor_names;
      # the outputs of save/RestoreV2 are connected to save/Assign
      if restore_name_pos is None:
        restore_name_pos = {}
        string_val = self._restore_tensor_node.attr['value'].tensor.string_val
        for pos, raw_name in enumerate(string_val):
          # setdefault keeps the first occurrence, same as list.index
          restore_name_pos.setdefault(self.bytes2str(raw_name), pos)
      tmp_id = restore_name_pos.get(node.input[0])
      if tmp_id is None:
        raise ValueError('%s: input %s is not in save/RestoreV2 tensor_names' %
                         (node.name, node.input[0]))
      if tmp_id != 0:
        tmp_input2 = 'save/RestoreV2:%d' % tmp_id
      else:
        tmp_input2 = 'save/RestoreV2'
      if tmp_input2 != node.input[1]:
        if self._verbose:
          logging.info("update save/Assign[%s]'s input from %s to %s" %
                       (node.name, node.input[1], tmp_input2))
        node.input[1] = tmp_input2
  # save/restore_all need save/restore_shard as input
  # save/restore_shard needs save/Assign_[0-N] as input
  # save/Assign_[0-N] needs save/RestoreV2 as input
  if self._restore_shard_node:
    shard_inputs = self._restore_shard_node.input
    for tmp_input in drop_save_assigns:
      if tmp_input not in shard_inputs:
        # remove() would raise ValueError; skip it, as the restore_all
        # branch below does
        logging.warning('restore_shard has no input: %s' % tmp_input)
        continue
      shard_inputs.remove(tmp_input)
      if self._verbose:
        logging.info('drop restore_shard input: %s' % tmp_input)
  elif len(self._restore_all_node) > 0:
    for tmp_input in drop_save_assigns:
      for tmp_node in self._restore_all_node:
        if tmp_input in tmp_node.input:
          tmp_node.input.remove(tmp_input)
          if self._verbose:
            logging.info('drop %s input: %s' % (tmp_node.name, tmp_input))
          break