def get_norm_embed_name()

in easy_rec/python/utils/proto_util.py [0:0]


def get_norm_embed_name(name, verbose=False):
  """Normalize an embedding variable/op name for embedding export to redis.

  Args:
    name: '/'-separated variable (or op) name.
    verbose: if True, log the mapping from name to the normalized name.

  Returns:
    A tuple (embedding_name, embedding_part_id):
      embedding_name: the normalized embedding name. For names of the form
        '.../embedding_weights:N' or '.../embedding_weights/part_P:N', a
        non-zero N is appended as '_N'.
      embedding_part_id: the partition index P parsed from a 'part_P'
        component following 'embedding_weights', otherwise 0.
    (None, None) if name matches none of the recognized embedding naming
    patterns; a warning is logged in that case.
  """
  norm_name, part_id = _match_norm_embed_name(name.split('/'))
  # Compare against None explicitly: a match may legitimately produce ''
  # (e.g. '/embedding_weights:0').
  if norm_name is None:
    logging.warning('Failed to norm: %s', name)
    return None, None
  if verbose:
    logging.info('norm %s to %s', name, norm_name)
  return norm_name, part_id


def _match_norm_embed_name(name_toks):
  """Match split name tokens against known embedding naming patterns.

  Returns (embedding_name, part_id), or (None, None) if nothing matches.
  Patterns are tried in priority order; within a pass the leftmost match wins.
  """
  var_prefix = 'embedding_weights:'
  part_prefix = 'part_'

  # Pass 1: explicit variable names, e.g.
  #   input_layer/x_embedding/embedding_weights:1        => ..._1, part 0
  #   input_layer/x_embedding/embedding_weights/part_3:1 => ..._1, part 3
  for i in range(len(name_toks) - 1):
    cur_tok, next_tok = name_toks[i], name_toks[i + 1]
    if next_tok.startswith(var_prefix):
      tmp_name = '/'.join(name_toks[:i + 1])
      var_id = next_tok[len(var_prefix):]
      if var_id != '0':
        tmp_name += '_' + var_id
      return tmp_name, 0
    # NOTE(review): `i > 1` requires at least two path components before
    # 'embedding_weights'; names like 'a/embedding_weights/part_3:0' fall
    # through to pass 3 and are returned with part id 0 - confirm intended.
    if (i > 1 and cur_tok == 'embedding_weights' and
        next_tok.startswith(part_prefix)):
      part_toks = next_tok[len(part_prefix):].split(':')
      try:
        part_id = int(part_toks[0])
      except ValueError:
        # Not a numeric partition suffix; keep scanning instead of crashing.
        continue
      tmp_name = '/'.join(name_toks[:i])
      if len(part_toks) >= 2 and part_toks[1] != '0':
        tmp_name += '_' + part_toks[1]
      return tmp_name, part_id

  # Pass 2: ops under a '<feature>_embedding_weights' scope, e.g.
  #   input_layer/app_category_embedding/app_category_embedding_weights/SparseReshape
  #   => input_layer/app_category_embedding
  for i in range(1, len(name_toks)):
    tok = name_toks[i]
    if tok.endswith('_embedding_weights') or '_embedding_weights_' in tok:
      return '/'.join(name_toks[:i]), 0

  # Pass 3: a bare 'embedding_weights' component, e.g.
  #   input_layer/app_category_embedding/embedding_weights
  #   => input_layer/app_category_embedding
  for i in range(1, len(name_toks)):
    if name_toks[i] == 'embedding_weights':
      return '/'.join(name_toks[:i]), 0

  return None, None