# easy_rec/python/utils/meta_graph_editor.py
# -*- encoding:utf-8 -*-
import logging
import os

import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.framework import ops
from tensorflow.python.platform.gfile import GFile
# from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.loader_impl import SavedModelLoader

from easy_rec.python.utils import conditional
from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import proto_util

EMBEDDING_INITIALIZERS = 'embedding_initializers'


class MetaGraphEditor:

  def __init__(self,
               lookup_lib_path,
               saved_model_dir,
               redis_url=None,
               redis_passwd=None,
               redis_timeout=0,
               redis_cache_names=[],
               oss_path=None,
               oss_endpoint=None,
               oss_ak=None,
               oss_sk=None,
               oss_timeout=0,
               meta_graph_def=None,
               norm_name_to_ids=None,
               incr_update_params=None,
               debug_dir=''):
    self._lookup_op = tf.load_op_library(lookup_lib_path)
    self._debug_dir = debug_dir
    self._verbose = debug_dir != ''
    if saved_model_dir:
      tags = ['serve']
      loader = SavedModelLoader(saved_model_dir)
      saver, _ = loader.load_graph(tf.get_default_graph(), tags, None)
      meta_graph_def = loader.get_meta_graph_def_from_tags(tags)
    else:
      assert meta_graph_def, 'either saved_model_dir or meta_graph_def must be set'
      tf.reset_default_graph()
      from tensorflow.python.framework import meta_graph
      meta_graph.import_scoped_meta_graph_with_return_elements(
          meta_graph_def, clear_devices=True)
      # tf.train.import_meta_graph(meta_graph_def)
    self._meta_graph_version = meta_graph_def.meta_info_def.meta_graph_version
    self._signature_def = meta_graph_def.signature_def[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    if self._verbose:
      debug_out_path = os.path.join(self._debug_dir, 'meta_graph_raw.txt')
      with GFile(debug_out_path, 'w') as fout:
        fout.write(text_format.MessageToString(meta_graph_def, as_utf8=True))
    self._meta_graph_def = meta_graph_def
    self._old_node_num = len(self._meta_graph_def.graph_def.node)
    self._all_graph_nodes = None
    self._all_graph_node_flags = None
    self._restore_tensor_node = None
    self._restore_shard_node = None
    self._restore_all_node = []
    self._lookup_outs = None
    self._feature_names = None
    self._embed_names = None
    self._embed_name_to_ids = norm_name_to_ids
    self._is_cache_from_redis = []
    self._redis_cache_names = redis_cache_names
    self._embed_ids = None
    self._embed_dims = None
    self._embed_sizes = None
    self._embed_combiners = None
    self._redis_url = redis_url
    self._redis_passwd = redis_passwd
    self._redis_timeout = redis_timeout
    self._oss_path = oss_path
    self._oss_endpoint = oss_endpoint
    self._oss_ak = oss_ak
    self._oss_sk = oss_sk
    self._oss_timeout = oss_timeout
    self._incr_update_params = incr_update_params
    # increment update placeholders
    self._embedding_update_inputs = {}
    self._embedding_update_outputs = {}
    self._dense_update_inputs = {}
    self._dense_update_outputs = {}

  @property
  def sparse_update_inputs(self):
    return self._embedding_update_inputs

  @property
  def sparse_update_outputs(self):
    return self._embedding_update_outputs

  @property
  def dense_update_inputs(self):
    return self._dense_update_inputs

  @property
  def dense_update_outputs(self):
    return self._dense_update_outputs

  @property
  def graph_def(self):
    return self._meta_graph_def.graph_def

  @property
  def signature_def(self):
    return self._signature_def

  @property
  def meta_graph_version(self):
    return self._meta_graph_version

  def init_graph_node_clear_flags(self):
    graph_def = self._meta_graph_def.graph_def
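    # keep a mutable copy of all graph nodes plus a parallel keep/drop flag
    # list, so the clear_* passes below can mark nodes for removal without
    # rewriting graph_def.node until the very end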
    self._all_graph_nodes = [n for n in graph_def.node]
    self._all_graph_node_flags = [True for n in graph_def.node]

  def _get_share_embed_name(self, x, embed_names):
    """Map shared embedding tensor names to embedding names.

    Args:
      x: string, embedding tensor name, such as:
         input_layer_1/shared_embed_1/field16_shared_embedding
         input_layer_1/shared_embed_2/field17_shared_embedding
         input_layer/shared_embed_wide/field15_shared_embedding
         input_layer/shared_embed_wide_1/field16_shared_embedding
      embed_names: all the candidate embedding names

    Return:
      one element in embed_names, such as:
         input_layer_1/shared_embed
         input_layer_1/shared_embed
         input_layer/shared_embed_wide
         input_layer/shared_embed_wide
    """
    assert x.endswith('_shared_embedding')
    name_toks = x.split('/')
    name_toks = name_toks[:-1]
    tmp = name_toks[-1]
    tmp = tmp.split('_')
    try:
      int(tmp[-1])
      name_toks[-1] = '_'.join(tmp[:-1])
    except Exception:
      pass
    tmp_name = '/'.join(name_toks[1:])
    sel_embed_name = ''
    for embed_name in embed_names:
      tmp_toks = embed_name.split('/')
      tmp_toks = tmp_toks[1:]
      embed_name_sub = '/'.join(tmp_toks)
      if tmp_name == embed_name_sub:
        assert not sel_embed_name, 'confusions encountered: %s %s' % (
            x, ','.join(embed_names))
        sel_embed_name = embed_name
    assert sel_embed_name, '%s not found in shared_embeddings: %s' % (
        tmp_name, ','.join(embed_names))
    return sel_embed_name

  def _find_embed_combiners(self, norm_embed_names):
    """Find embedding lookup combiner methods.

    Args:
      norm_embed_names: normalized embedding names

    Return:
      list: combiner method for each feature: sum, mean, sqrtn
    """
    embed_combiners = {}
    embed_combine_node_cts = {}
    combiner_map = {
        'SparseSegmentSum': 'sum',
        'SparseSegmentMean': 'mean',
        'SparseSegmentSqrtN': 'sqrtn'
    }
    for node in self._meta_graph_def.graph_def.node:
      if node.op in combiner_map:
        norm_name, _ = proto_util.get_norm_embed_name(node.name)
        embed_combiners[norm_name] = combiner_map[node.op]
        embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
            norm_name, 0) + 1
      elif node.op == 'RealDiv' and len(node.input) == 2:
        # for tag feature with weights, and combiner == mean
        if 'SegmentSum' in node.input[0] and 'SegmentSum' in node.input[1]:
          norm_name, _ = proto_util.get_norm_embed_name(node.name)
          embed_combiners[norm_name] = 'mean'
          embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
              norm_name, 0) + 1
      elif node.op == 'SegmentSum':
        norm_name, _ = proto_util.get_norm_embed_name(node.name)
        # avoid overwriting RealDiv results
        if norm_name not in embed_combiners:
          embed_combiners[norm_name] = 'sum'
          embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
              norm_name, 0) + 1
    return [embed_combiners[x] for x in norm_embed_names]

  def _find_lookup_indices_values_shapes(self):
    # use the specific _embedding_weights/SparseReshape to find out
    # lookup inputs: indices, values, dense_shape, weights
    indices = {}
    values = {}
    shapes = {}

    def _get_output_shape(graph_def, input_name):
      out_id = 0
      if ':' in input_name:
        node_name, out_id = input_name.split(':')
        out_id = int(out_id)
      else:
        node_name = input_name
      for node in graph_def.node:
        if node.name == node_name:
          return node.attr['_output_shapes'].list.shape[out_id]
      return None

    for node in self._meta_graph_def.graph_def.node:
      if '_embedding_weights/SparseReshape' in node.name:
        if node.op == 'SparseReshape':
          # embed_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
          fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
          for tmp_input in node.input:
            tmp_shape = _get_output_shape(self._meta_graph_def.graph_def,
                                          tmp_input)
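            # classify each SparseReshape input by its static rank: the Cast
            # input is the reshape target and is skipped, a rank-2 tensor is
            # the sparse indices, and a rank-1 tensor is the dense_shape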
            if '_embedding_weights/Cast' in tmp_input:
              continue
            elif len(tmp_shape.dim) == 2:
              indices[fea_name] = tmp_input
            elif len(tmp_shape.dim) == 1:
              shapes[fea_name] = tmp_input
        elif node.op == 'Identity':
          fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
          values[fea_name] = node.input[0]
    return indices, values, shapes

  def _find_lookup_weights(self):
    weights = {}
    for node in self._meta_graph_def.graph_def.node:
      if '_weighted_by_' in node.name and 'GatherV2' in node.name:
        has_sparse_reshape = False
        for tmp_input in node.input:
          if 'SparseReshape' in tmp_input:
            has_sparse_reshape = True
        if has_sparse_reshape:
          continue
        if len(node.input) != 3:
          continue
        # try to find nodes with weights
        # input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_[0-9]
        # which has three inputs:
        #   input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/Reshape_1
        #   DeserializeSparse_1 (this is the weight)
        #   input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_4/axis
        fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
        for tmp_input in node.input:
          if '_weighted_by_' not in tmp_input:
            weights[fea_name] = tmp_input
    return weights

  def _find_embed_names_and_dims(self, norm_embed_names):
    # get embedding dimensions from Variables
    embed_dims = {}
    embed_sizes = {}
    embed_is_kv = {}
    for node in self._meta_graph_def.graph_def.node:
      if 'embedding_weights' in node.name and node.op in [
          'VariableV2', 'KvVarHandleOp'
      ]:
        tmp = node.attr['shape'].shape.dim[-1].size
        tmp2 = 1
        for x in node.attr['shape'].shape.dim[:-1]:
          tmp2 = tmp2 * x.size
        embed_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
        assert embed_name is not None,\
            'fail to get_norm_embed_name(%s)' % node.name
        embed_dims[embed_name] = tmp
        embed_sizes[embed_name] = tmp2
        embed_is_kv[embed_name] = 1 if node.op == 'KvVarHandleOp' else 0
    # get all embedding dimensions, note that some embeddings
    # are shared by multiple inputs, so the names should be
    # transformed
    all_embed_dims = []
    all_embed_names = []
    all_embed_sizes = []
    all_embed_is_kv = []
    for x in norm_embed_names:
      if x in embed_dims:
        all_embed_names.append(x)
        all_embed_dims.append(embed_dims[x])
        all_embed_sizes.append(embed_sizes[x])
        all_embed_is_kv.append(embed_is_kv[x])
      elif x.endswith('_shared_embedding'):
        tmp_embed_name = self._get_share_embed_name(x, embed_dims.keys())
        all_embed_names.append(tmp_embed_name)
        all_embed_dims.append(embed_dims[tmp_embed_name])
        all_embed_sizes.append(embed_sizes[tmp_embed_name])
        all_embed_is_kv.append(embed_is_kv[tmp_embed_name])
    return all_embed_names, all_embed_dims, all_embed_sizes, all_embed_is_kv

  def find_lookup_inputs(self):
    logging.info('Extract embedding_lookup inputs')
    indices, values, shapes = self._find_lookup_indices_values_shapes()
    weights = self._find_lookup_weights()
    for fea in shapes.keys():
      logging.info('Lookup Input[%s]: indices=%s values=%s shapes=%s' %
                   (fea, indices[fea], values[fea], shapes[fea]))

    graph = tf.get_default_graph()

    def _get_tensor_by_name(tensor_name):
      if ':' not in tensor_name:
        tensor_name = tensor_name + ':0'
      return graph.get_tensor_by_name(tensor_name)

    lookup_input_values = []
    lookup_input_indices = []
    lookup_input_shapes = []
    lookup_input_weights = []
    for key in values.keys():
      tmp_val, tmp_ind, tmp_shape = values[key], indices[key], shapes[key]
      lookup_input_values.append(_get_tensor_by_name(tmp_val))
      lookup_input_indices.append(_get_tensor_by_name(tmp_ind))
      lookup_input_shapes.append(_get_tensor_by_name(tmp_shape))
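      # only tag/sequence features declared with a weight column have a
      # per-id weight tensor; other features get an empty placeholder so the
      # four input lists stay aligned per feature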
      if key in weights:
        tmp_w = weights[key]
        lookup_input_weights.append(_get_tensor_by_name(tmp_w))
      else:
        lookup_input_weights.append([])

    # get embedding combiners
    self._embed_combiners = self._find_embed_combiners(values.keys())
    # get embedding dimensions
    self._embed_names, self._embed_dims, self._embed_sizes, self._embed_is_kv\
        = self._find_embed_names_and_dims(values.keys())
    if not self._embed_name_to_ids:
      embed_name_uniq = list(set(self._embed_names))
      self._embed_name_to_ids = {
          t: tid for tid, t in enumerate(embed_name_uniq)
      }
    self._embed_ids = [
        int(self._embed_name_to_ids[x]) for x in self._embed_names
    ]
    self._is_cache_from_redis = [
        proto_util.is_cache_from_redis(x, self._redis_cache_names)
        for x in self._embed_names
    ]
    # normalized feature names
    self._feature_names = list(values.keys())
    return lookup_input_indices, lookup_input_values, lookup_input_shapes,\
        lookup_input_weights

  def add_lookup_op(self, lookup_input_indices, lookup_input_values,
                    lookup_input_shapes, lookup_input_weights):
    logging.info('add custom lookup operation to lookup embeddings from redis')
    self._lookup_outs = [None for i in range(len(lookup_input_values))]
    for i in range(len(lookup_input_values)):
      if lookup_input_values[i].dtype == tf.int32:
        lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
    for i in range(len(self._lookup_outs)):
      i_1 = i + 1
      self._lookup_outs[i] = self._lookup_op.kv_lookup(
          lookup_input_indices[i:i_1],
          lookup_input_values[i:i_1],
          lookup_input_shapes[i:i_1],
          lookup_input_weights[i:i_1],
          url=self._redis_url,
          password=self._redis_passwd,
          timeout=self._redis_timeout,
          combiners=self._embed_combiners[i:i_1],
          embedding_dims=self._embed_dims[i:i_1],
          embedding_names=self._embed_ids[i:i_1],
          cache=self._is_cache_from_redis,
          version=self._meta_graph_version)[0]
    meta_graph_def = tf.train.export_meta_graph()

    if self._verbose:
      debug_path = os.path.join(self._debug_dir, 'graph_raw.txt')
      with GFile(debug_path, 'w') as fout:
        fout.write(
            text_format.MessageToString(
                self._meta_graph_def.graph_def, as_utf8=True))
    return meta_graph_def

  def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
                        lookup_input_shapes, lookup_input_weights):
    logging.info('add custom lookup operation to lookup embeddings from oss')
    place_on_cpu = os.getenv('place_embedding_on_cpu')
    place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
    with conditional(place_on_cpu, ops.device('/CPU:0')):
      for i in range(len(lookup_input_values)):
        if lookup_input_values[i].dtype == tf.int32:
          lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
      # N = len(lookup_input_indices)
      # self._lookup_outs = [ None for _ in range(N) ]
      # for i in range(N):
      #   i_1 = i + 1
      #   self._lookup_outs[i] = self._lookup_op.oss_read_kv(
      #       lookup_input_indices[i:i_1],
      #       lookup_input_values[i:i_1],
      #       lookup_input_shapes[i:i_1],
      #       lookup_input_weights[i:i_1],
      #       osspath=self._oss_path,
      #       endpoint=self._oss_endpoint,
      #       ak=self._oss_ak,
      #       sk=self._oss_sk,
      #       timeout=self._oss_timeout,
      #       combiners=self._embed_combiners[i:i_1],
      #       embedding_dims=self._embed_dims[i:i_1],
      #       embedding_ids=self._embed_ids[i:i_1],
      #       embedding_is_kv=self._embed_is_kv[i:i_1],
      #       shared_name='embedding_lookup_res',
      #       name='embedding_lookup_fused/lookup')[0]
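      # all features are served by a single fused oss_read_kv op (the
      # commented-out per-feature version above is kept for reference);
      # shared_name ties this op to the oss_init op created below, so both
      # presumably refer to the same embedding lookup resource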
      self._lookup_outs = self._lookup_op.oss_read_kv(
          lookup_input_indices,
          lookup_input_values,
          lookup_input_shapes,
          lookup_input_weights,
          osspath=self._oss_path,
          endpoint=self._oss_endpoint,
          ak=self._oss_ak,
          sk=self._oss_sk,
          timeout=self._oss_timeout,
          combiners=self._embed_combiners,
          embedding_dims=self._embed_dims,
          embedding_ids=self._embed_ids,
          embedding_is_kv=self._embed_is_kv,
          shared_name='embedding_lookup_res',
          name='embedding_lookup_fused/lookup')

    N = np.max([int(x) for x in self._embed_ids]) + 1
    uniq_embed_ids = [x for x in range(N)]
    uniq_embed_dims = [0 for x in range(N)]
    uniq_embed_combiners = ['mean' for x in range(N)]
    uniq_embed_is_kvs = [0 for x in range(N)]
    for embed_id, embed_combiner, embed_is_kv, embed_dim in zip(
        self._embed_ids, self._embed_combiners, self._embed_is_kv,
        self._embed_dims):
      uniq_embed_combiners[embed_id] = embed_combiner
      uniq_embed_is_kvs[embed_id] = embed_is_kv
      uniq_embed_dims[embed_id] = embed_dim
    lookup_init_op = self._lookup_op.oss_init(
        osspath=self._oss_path,
        endpoint=self._oss_endpoint,
        ak=self._oss_ak,
        sk=self._oss_sk,
        combiners=uniq_embed_combiners,
        embedding_dims=uniq_embed_dims,
        embedding_ids=uniq_embed_ids,
        embedding_is_kv=uniq_embed_is_kvs,
        N=N,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/init')
    ops.add_to_collection(EMBEDDING_INITIALIZERS, lookup_init_op)

    if self._incr_update_params is not None:
      # all sparse variables are updated by a single custom operation
      message_ph = tf.placeholder(tf.int8, [None], name='incr_update/message')
      embedding_update = self._lookup_op.embedding_update(
          message=message_ph,
          shared_name='embedding_lookup_res',
          name='embedding_lookup_fused/embedding_update')
      self._embedding_update_inputs['incr_update/sparse/message'] = message_ph
      self._embedding_update_outputs[
          'incr_update/sparse/embedding_update'] = embedding_update
      # dense variables are updated one by one
      dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
      for x in ops.get_collection(constant.DENSE_UPDATE_VARIABLES):
        dense_var_id = dense_name_to_ids[x.op.name]
        dense_input_name = 'incr_update/dense/%d/input' % dense_var_id
        dense_output_name = 'incr_update/dense/%d/output' % dense_var_id
        dense_update_input = tf.placeholder(
            tf.float32, x.get_shape(), name=dense_input_name)
        self._dense_update_inputs[dense_input_name] = dense_update_input
        dense_assign_op = tf.assign(x, dense_update_input)
        self._dense_update_outputs[dense_output_name] = dense_assign_op

    meta_graph_def = tf.train.export_meta_graph()
    if self._verbose:
      debug_path = os.path.join(self._debug_dir, 'graph_raw.txt')
      with GFile(debug_path, 'w') as fout:
        fout.write(
            text_format.MessageToString(
                self._meta_graph_def.graph_def, as_utf8=True))
    return meta_graph_def

  def bytes2str(self, x):
    if bytes == str:
      return x
    else:
      try:
        return x.decode('utf-8')
      except Exception:
        # in case of some special chars in protobuf
        return str(x)

  def clear_meta_graph_embeding(self, meta_graph_def):
    logging.info('clear meta graph embedding_weights')

    def _clear_embedding_in_meta_collect(meta_graph_def, collect_name):
      tmp_vals = [
          x for x in
          meta_graph_def.collection_def[collect_name].bytes_list.value
          if 'embedding_weights' not in self.bytes2str(x)
      ]
      meta_graph_def.collection_def[collect_name].bytes_list.ClearField('value')
      for tmp_v in tmp_vals:
        meta_graph_def.collection_def[collect_name].bytes_list.value.append(
            tmp_v)

    _clear_embedding_in_meta_collect(meta_graph_def, 'model_variables')
    _clear_embedding_in_meta_collect(meta_graph_def, 'trainable_variables')
    _clear_embedding_in_meta_collect(meta_graph_def, 'variables')

    # clear Kv(pai embedding variable) ops in meta_info_def.stripped_op_list.op
    kept_ops = [
        x for x in meta_graph_def.meta_info_def.stripped_op_list.op
        if x.name not in [
            'InitializeKvVariableOp', 'KvResourceGather', 'KvResourceImportV2',
            'KvVarHandleOp', 'KvVarIsInitializedOp', 'ReadKvVariableOp'
        ]
    ]
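    # protobuf repeated fields cannot be assigned directly, so rewrite
    # stripped_op_list in place by clearing it and re-extending it with the
    # op defs that were kept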
    meta_graph_def.meta_info_def.stripped_op_list.ClearField('op')
    meta_graph_def.meta_info_def.stripped_op_list.op.extend(kept_ops)
    for tmp_op in meta_graph_def.meta_info_def.stripped_op_list.op:
      if tmp_op.name == 'SaveV2':
        for tmp_id, tmp_attr in enumerate(tmp_op.attr):
          if tmp_attr.name == 'has_ev':
            tmp_op.attr.remove(tmp_attr)
            break

  def clear_meta_collect(self, meta_graph_def):
    drop_meta_collects = []
    for key in meta_graph_def.collection_def:
      val = meta_graph_def.collection_def[key]
      if val.HasField('node_list'):
        if 'embedding_weights' in val.node_list.value[
            0] and 'easy_rec' not in val.node_list.value[0]:
          drop_meta_collects.append(key)
      elif key == 'saved_model_assets':
        drop_meta_collects.append(key)
    for key in drop_meta_collects:
      meta_graph_def.collection_def.pop(key)

  def remove_embedding_weights_and_update_lookup_outputs(self):

    def _should_drop(name):
      if '_embedding_weights' in name:
        if self._verbose:
          logging.info('[SHOULD_DROP] %s' % name)
        return True

    logging.info('remove embedding_weights node in graph_def.node')
    logging.info(
        'and replace the old embedding_lookup outputs with new lookup_op outputs'
    )
    for tid, node in enumerate(self._all_graph_nodes):
      # drop the nodes
      if _should_drop(node.name):
        self._all_graph_node_flags[tid] = False
      else:
        for i in range(len(node.input)):
          if _should_drop(node.input[i]):
            input_name, _ = proto_util.get_norm_embed_name(
                node.input[i], self._verbose)
            print('REPLACE:' + node.input[i] + '=>' + input_name)
            input_name = self._lookup_outs[self._feature_names.index(
                input_name)].name
            if input_name.endswith(':0'):
              input_name = input_name.replace(':0', '')
            node.input[i] = input_name

  # drop by ids
  def _drop_by_ids(self, tmp_obj, key, drop_ids):
    keep_vals = [
        x for i, x in enumerate(getattr(tmp_obj, key)) if i not in drop_ids
    ]
    tmp_obj.ClearField(key)
    getattr(tmp_obj, key).extend(keep_vals)
  def clear_save_restore(self):
    """Clear save restore ops.

    save/restore_all needs save/restore_shard as input,
    save/restore_shard needs save/Assign_[0-N] as input,
    save/Assign_[0-N] needs save/RestoreV2 as input,
    save/RestoreV2 uses save/RestoreV2/tensor_names and
        save/RestoreV2/shape_and_slices as input.

    Edit [ save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices,
        save/RestoreV2, save/restore_shard ].
    """
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
      if node.name == 'save/RestoreV2/tensor_names':
        self._restore_tensor_node = node
        break
    # assert self._restore_tensor_node is not None, 'save/RestoreV2/tensor_names is not found'
    if self._restore_tensor_node:
      drop_ids = []
      for tmp_id, tmp_name in enumerate(
          self._restore_tensor_node.attr['value'].tensor.string_val):
        if 'embedding_weights' in self.bytes2str(tmp_name):
          drop_ids.append(tmp_id)
      self._drop_by_ids(self._restore_tensor_node.attr['value'].tensor,
                        'string_val', drop_ids)
      keep_node_num = len(
          self._restore_tensor_node.attr['value'].tensor.string_val)
      logging.info(
          'update self._restore_tensor_node: string_val keep_num = %d drop_num = %d'
          % (keep_node_num, len(drop_ids)))
      self._restore_tensor_node.attr['value'].tensor.tensor_shape.dim[
          0].size = keep_node_num
      self._restore_tensor_node.attr['_output_shapes'].list.shape[0].dim[
          0].size = keep_node_num

    logging.info(
        'update save/RestoreV2, drop tensor_shapes, _output_shapes, related to embedding_weights'
    )
    self._restore_shard_node = None
    for node_id, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[node_id]:
        continue
      if node.name == 'save/RestoreV2/shape_and_slices':
        node.attr['value'].tensor.tensor_shape.dim[0].size = keep_node_num
        node.attr['_output_shapes'].list.shape[0].dim[0].size = keep_node_num
        self._drop_by_ids(node.attr['value'].tensor, 'string_val', drop_ids)
      elif node.name == 'save/RestoreV2':
        self._drop_by_ids(node.attr['_output_shapes'].list, 'shape', drop_ids)
        self._drop_by_ids(node.attr['dtypes'].list, 'type', drop_ids)
      elif node.name == 'save/restore_shard':
        self._restore_shard_node = node
      elif node.name.startswith('save/restore_all'):
        self._restore_all_node.append(node)

  def clear_save_assign(self):
    logging.info(
        'update save/Assign, drop tensor_shapes, _output_shapes, related to embedding_weights'
    )
    # edit save/Assign
    drop_save_assigns = []
    all_kv_drop = []
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
      if node.op == 'Assign' and 'save/Assign' in node.name and \
          'embedding_weights' in node.input[0]:
        drop_save_assigns.append('^' + node.name)
        self._all_graph_node_flags[tid] = False
      elif 'embedding_weights/ConcatPartitions/concat' in node.name:
        self._all_graph_node_flags[tid] = False
      elif node.name.endswith('/embedding_weights') and node.op == 'Identity':
        self._all_graph_node_flags[tid] = False
      elif 'save/KvResourceImportV2' in node.name and \
          node.op == 'KvResourceImportV2':
        drop_save_assigns.append('^' + node.name)
        self._all_graph_node_flags[tid] = False
      elif 'KvResourceImportV2' in node.name:
        self._all_graph_node_flags[tid] = False
      elif 'save/Const' in node.name and node.op == 'Const':
        if '_class' in node.attr and len(node.attr['_class'].list.s) > 0:
          const_name = node.attr['_class'].list.s[0]
          if not isinstance(const_name, str):
            const_name = const_name.decode('utf-8')
          if 'embedding_weights' in const_name:
            self._all_graph_node_flags[tid] = False
      elif 'ReadKvVariableOp' in node.name and node.op == 'ReadKvVariableOp':
        all_kv_drop.append(node.name)
        self._all_graph_node_flags[tid] = False
      elif node.op == 'Assign' and 'save/Assign' in node.name:
        # update node(save/Assign_[0-N])'s input[1] by the position of
        # node.input[0] in save/RestoreV2/tensor_names;
        # the outputs of save/RestoreV2 are connected to save/Assign
        tmp_id = [
            self.bytes2str(x)
            for x in self._restore_tensor_node.attr['value'].tensor.string_val
        ].index(node.input[0])
        if tmp_id != 0:
          tmp_input2 = 'save/RestoreV2:%d' % tmp_id
        else:
          tmp_input2 = 'save/RestoreV2'
        if tmp_input2 != node.input[1]:
          if self._verbose:
            logging.info("update save/Assign[%s]'s input from %s to %s" %
                         (node.name, node.input[1], tmp_input2))
          node.input[1] = tmp_input2
    # save/restore_all needs save/restore_shard as input
    # save/restore_shard needs save/Assign_[0-N] as input
    # save/Assign_[0-N] needs save/RestoreV2 as input
    if self._restore_shard_node:
      for tmp_input in drop_save_assigns:
        self._restore_shard_node.input.remove(tmp_input)
        if self._verbose:
          logging.info('drop restore_shard input: %s' % tmp_input)
    elif len(self._restore_all_node) > 0:
      for tmp_input in drop_save_assigns:
        for tmp_node in self._restore_all_node:
          if tmp_input in tmp_node.input:
            tmp_node.input.remove(tmp_input)
            if self._verbose:
              logging.info('drop %s input: %s' % (tmp_node.name, tmp_input))
            break

  def clear_save_v2(self):
    """Clear SaveV2 ops.

    save/Identity needs [ save/MergeV2Checkpoints, save/control_dependency ]
        as input,
    save/MergeV2Checkpoints needs
        [ save/MergeV2Checkpoints/checkpoint_prefixes ] as input,
    save/MergeV2Checkpoints/checkpoint_prefixes needs [ save/ShardedFilename,
        save/control_dependency ] as input,
    save/control_dependency needs save/SaveV2 as input,
    save/SaveV2 inputs: [ save/SaveV2/tensor_names,
        save/SaveV2/shape_and_slices ].

    Edit save/SaveV2, save/SaveV2/shape_and_slices, save/SaveV2/tensor_names.
    """
    logging.info('update save/SaveV2 input shape, _output_shapes, tensor_shape')
    save_drop_ids = []
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
      if node.name == 'save/SaveV2' and node.op == 'SaveV2':
        for tmp_id, tmp_input in enumerate(node.input):
          if '/embedding_weights' in tmp_input:
            save_drop_ids.append(tmp_id)
        diff_num = len(node.input) - len(node.attr['dtypes'].list.type)
        self._drop_by_ids(node, 'input', save_drop_ids)
        save_drop_ids = [x - diff_num for x in save_drop_ids]
        self._drop_by_ids(node.attr['dtypes'].list, 'type', save_drop_ids)
        if 'has_ev' in node.attr:
          del node.attr['has_ev']

    for node in self._all_graph_nodes:
      if node.name == 'save/SaveV2/shape_and_slices' and node.op == 'Const':
        # _output_shapes
        #   size
        # string_val
        node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
            save_drop_ids)
        node.attr['value'].tensor.tensor_shape.dim[0].size -= len(save_drop_ids)
        self._drop_by_ids(node.attr['value'].tensor, 'string_val',
                          save_drop_ids)
      elif node.name == 'save/SaveV2/tensor_names':
        # tensor_names may not have the same order as save/SaveV2/shape_and_slices
        tmp_drop_ids = [
            tmp_id for tmp_id, tmp_val in enumerate(
                node.attr['value'].tensor.string_val)
            if 'embedding_weights' in self.bytes2str(tmp_val)
        ]
        # attr['value'].tensor.string_val
        # tensor_shape
        #   size
        assert len(save_drop_ids) == len(tmp_drop_ids)
        node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
            tmp_drop_ids)
        node.attr['value'].tensor.tensor_shape.dim[0].size -= len(tmp_drop_ids)
        self._drop_by_ids(node.attr['value'].tensor, 'string_val',
                          tmp_drop_ids)
  def clear_initialize(self):
    """Clear initialization ops.

    */read(Identity) depends on [*(VariableV2)],
    */Assign depends on [*/Initializer/*, *(VariableV2)].

    Drop embedding_weights initialization nodes:
      */embedding_weights/part_x [, /Assign, /read]
      */embedding_weights/part_1/Initializer/truncated_normal
          [, /shape, /mean, /stddev, /TruncatedNormal, /mul]
    """
    logging.info('Remove Initialization nodes for embedding_weights')
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
      if 'embedding_weights' in node.name and 'Initializer' in node.name:
        self._all_graph_node_flags[tid] = False
      elif 'embedding_weights' in node.name and 'Assign' in node.name:
        self._all_graph_node_flags[tid] = False
      elif 'embedding_weights' in node.name and node.op == 'VariableV2':
        self._all_graph_node_flags[tid] = False
      elif 'embedding_weights' in node.name and node.name.endswith(
          '/read') and node.op == 'Identity':
        self._all_graph_node_flags[tid] = False
      elif 'embedding_weights' in node.name and node.op == 'Identity':
        node_toks = node.name.split('/')
        node_tok = node_toks[-1]
        if 'embedding_weights_' in node_tok:
          node_tok = node_tok[len('embedding_weights_'):]
          try:
            int(node_tok)
            self._all_graph_node_flags[tid] = False
          except Exception:
            pass

  def clear_embedding_variable(self):
    # for pai embedding variable, we drop some special nodes
    for tid, node in enumerate(self._all_graph_nodes):
      if not self._all_graph_node_flags[tid]:
        continue
      if node.op in [
          'ReadKvVariableOp', 'KvVarIsInitializedOp', 'KvVarHandleOp'
      ]:
        self._all_graph_node_flags[tid] = False

  # there may be some nodes that depend on the dropped nodes; they are
  # dropped as well
  def drop_dependent_nodes(self):
    drop_names = [
        tmp_node.name for tid, tmp_node in enumerate(self._all_graph_nodes)
        if not self._all_graph_node_flags[tid]
    ]
    while True:
      more_drop_names = []
      for tid, tmp_node in enumerate(self._all_graph_nodes):
        if not self._all_graph_node_flags[tid]:
          continue
        if len(tmp_node.input) > 0 and tmp_node.input[0] in drop_names:
          logging.info('drop dependent node: %s depends on %s' %
                       (tmp_node.name, tmp_node.input[0]))
          self._all_graph_node_flags[tid] = False
          more_drop_names.append(tmp_node.name)
      drop_names = more_drop_names
      if not drop_names:
        break
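
  # The two edit entry points below run the same pipeline and differ only in
  # which fused lookup op is inserted (redis kv_lookup vs. oss oss_read_kv):
  #   1. find_lookup_inputs: locate the indices/values/shapes/weights tensors
  #      feeding the original embedding lookups;
  #   2. add_lookup_op / add_oss_lookup_op: wire those tensors into the
  #      custom lookup op;
  #   3. the clear_* passes: strip embedding variables and their
  #      save/restore, SaveV2 and initializer nodes from the graph and the
  #      meta collections;
  #   4. drop_dependent_nodes plus the node rebuild: remove anything that
  #      still depends on deleted nodes and rewrite graph_def.node.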
  def edit_graph(self):
    # the main entrance
    lookup_input_indices, lookup_input_values, lookup_input_shapes,\
        lookup_input_weights = self.find_lookup_inputs()
    # add lookup op to the graph
    self._meta_graph_def = self.add_lookup_op(lookup_input_indices,
                                              lookup_input_values,
                                              lookup_input_shapes,
                                              lookup_input_weights)
    self.clear_meta_graph_embeding(self._meta_graph_def)
    self.clear_meta_collect(self._meta_graph_def)
    self.init_graph_node_clear_flags()
    self.remove_embedding_weights_and_update_lookup_outputs()
    # save/RestoreV2
    self.clear_save_restore()
    # save/Assign
    self.clear_save_assign()
    # save/SaveV2
    self.clear_save_v2()
    self.clear_initialize()
    self.clear_embedding_variable()
    self.drop_dependent_nodes()
    self._meta_graph_def.graph_def.ClearField('node')
    self._meta_graph_def.graph_def.node.extend([
        x for tid, x in enumerate(self._all_graph_nodes)
        if self._all_graph_node_flags[tid]
    ])
    logging.info('old node number = %d' % self._old_node_num)
    logging.info('node number = %d' % len(self._meta_graph_def.graph_def.node))
    if self._verbose:
      debug_dump_path = os.path.join(self._debug_dir, 'graph.txt')
      with GFile(debug_dump_path, 'w') as fout:
        fout.write(text_format.MessageToString(self.graph_def, as_utf8=True))
      debug_dump_path = os.path.join(self._debug_dir, 'meta_graph.txt')
      with GFile(debug_dump_path, 'w') as fout:
        fout.write(
            text_format.MessageToString(self._meta_graph_def, as_utf8=True))

  def edit_graph_for_oss(self):
    # the main entrance
    lookup_input_indices, lookup_input_values, lookup_input_shapes,\
        lookup_input_weights = self.find_lookup_inputs()
    # add lookup op to the graph
    self._meta_graph_def = self.add_oss_lookup_op(lookup_input_indices,
                                                  lookup_input_values,
                                                  lookup_input_shapes,
                                                  lookup_input_weights)
    self.clear_meta_graph_embeding(self._meta_graph_def)
    self.clear_meta_collect(self._meta_graph_def)
    self.init_graph_node_clear_flags()
    self.remove_embedding_weights_and_update_lookup_outputs()
    # save/RestoreV2
    self.clear_save_restore()
    # save/Assign
    self.clear_save_assign()
    # save/SaveV2
    self.clear_save_v2()
    self.clear_initialize()
    self.clear_embedding_variable()
    self.drop_dependent_nodes()
    self._meta_graph_def.graph_def.ClearField('node')
    self._meta_graph_def.graph_def.node.extend([
        x for tid, x in enumerate(self._all_graph_nodes)
        if self._all_graph_node_flags[tid]
    ])
    logging.info('old node number = %d' % self._old_node_num)
    logging.info('node number = %d' % len(self._meta_graph_def.graph_def.node))
    if self._verbose:
      debug_dump_path = os.path.join(self._debug_dir, 'graph.txt')
      with GFile(debug_dump_path, 'w') as fout:
        fout.write(text_format.MessageToString(self.graph_def, as_utf8=True))
      debug_dump_path = os.path.join(self._debug_dir, 'meta_graph.txt')
      with GFile(debug_dump_path, 'w') as fout:
        fout.write(
            text_format.MessageToString(self._meta_graph_def, as_utf8=True))
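
# A minimal usage sketch (illustrative only, not part of the original module).
# The paths below are hypothetical placeholders: a real run needs the compiled
# custom lookup library (providing kv_lookup / oss_read_kv / oss_init) and a
# saved_model exported by EasyRec.
#
#   editor = MetaGraphEditor(
#       lookup_lib_path='/path/to/libembed_op.so',  # hypothetical path
#       saved_model_dir='/path/to/saved_model',     # hypothetical path
#       oss_path='oss://bucket/embedding/',         # hypothetical path
#       oss_endpoint='oss-endpoint',
#       oss_ak='access_key',
#       oss_sk='secret_key',
#       debug_dir='')
#   editor.edit_graph_for_oss()
#   # editor.graph_def and editor.signature_def now describe the rewritten
#   # graph, with embeddings served through the fused lookup op instead of
#   # in-graph embedding variables.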