# easy_rec/python/compat/embedding_parallel_saver.py
# -*- encoding:utf-8 -*-
import logging
import os
import numpy as np
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver
from easy_rec.python.utils import constant

try:
import horovod.tensorflow as hvd
from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
from easy_rec.python.compat import dynamic_variable
except Exception:
  hvd = None
  dynamic_variable_ops = None
  dynamic_variable = None

try:
from tensorflow.python.framework.load_library import load_op_library
import easy_rec
load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
logging.warning('load libload_embed.so failed: %s' % str(ex))
load_embed_lib = None


def _get_embed_part_id(embed_file):
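  """Parse the shard id out of an embedding shard file name.

  For example, '.../embed-fc__weights:0-part-3.bin' yields 3.
  """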
embed_file = embed_file.split('/')[-1]
embed_file = embed_file.split('.')[0]
embed_id = embed_file.split('-')[-1]
return int(embed_id)


class EmbeddingParallelSaver(saver.Saver):
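  """Saver that handles embedding variables sharded across Horovod workers.

  Dense parallel embeddings and SOK dynamic (key-value) embeddings are saved
  by every worker to side files named
  '<ckpt>-embedding/embed-<var_name>-part-<rank>.{bin,key,val}', while all
  remaining variables go through the regular tf.train.Saver path (saved by
  worker 0 only). On restore, the shards are re-partitioned, so the number
  of workers may change between save and restore.
  """
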
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
self._kv_vars = []
self._embed_vars = []
tf_vars = []
embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
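    # split vars into three groups: SOK dynamic (key-value) embeddings,
    # dense parallel embeddings, and everything else, which is left to the
    # parent tf.train.Saver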
for var in var_list:
if dynamic_variable is not None and isinstance(
var, dynamic_variable.DynamicVariable):
self._kv_vars.append(var)
elif var.name in embed_para_vars:
logging.info('save shard embedding %s part_id=%d part_shape=%s' %
(var.name, hvd.rank(), var.get_shape()))
self._embed_vars.append(var)
else:
tf_vars.append(var)
super(EmbeddingParallelSaver, self).__init__(
tf_vars,
reshape=reshape,
sharded=sharded,
max_to_keep=max_to_keep,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
name=name,
restore_sequentially=restore_sequentially,
saver_def=saver_def,
builder=builder,
defer_build=defer_build,
allow_empty=allow_empty,
write_version=write_version,
pad_step_number=pad_step_number,
save_relative_paths=save_relative_paths,
filename=filename)

  def _has_embed_vars(self):
return (len(self._kv_vars) + len(self._embed_vars)) > 0

  def _save_dense_embedding(self, embed_var):
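    """Build an op that dumps this worker's shard of a dense embedding.

    The shard is written as raw bytes to
    '<ckpt>-embedding/embed-<var_name>-part-<rank>.bin'; worker 0 also
    removes stale shards left by previous runs with more workers.
    """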
logging.info('task[%d] save_dense_embed: %s' % (hvd.rank(), embed_var.name))

    def _save_embed(embed, filename, var_name):
task_id = hvd.rank()
filename = filename.decode('utf-8')
var_name = var_name.decode('utf-8').replace('/', '__')
embed_dir = filename + '-embedding/'
logging.info('task[%d] save_dense_embed: %s to %s' %
(task_id, var_name, embed_dir))
if not gfile.Exists(embed_dir):
gfile.MakeDirs(embed_dir)
embed_file = filename + '-embedding/embed-' + var_name + '-part-%d.bin' % task_id
with gfile.GFile(embed_file, 'wb') as fout:
fout.write(embed.tobytes())
      if task_id == 0:
        # remove stale shards left over from a previous run with more workers
        embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
        embed_files = gfile.Glob(embed_pattern)
        for old_file in embed_files:
          embed_id = _get_embed_part_id(old_file)
          if embed_id >= hvd.size():
            gfile.Remove(old_file)
      return np.asarray([embed_file], order='C', dtype=object)

    file_name = ops.get_default_graph().get_tensor_by_name(
self.saver_def.filename_tensor_name)
save_paths = script_ops.py_func(_save_embed,
[embed_var, file_name, embed_var.name],
dtypes.string)
return save_paths

  def _load_dense_embedding(self, embed_var):
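    """Build an op that restores this worker's shard of a dense embedding.

    Rows are partitioned round-robin by global row id: row i of saved part p
    holds global id p + i * old_part_num; rows with
    global_id % part_num == rank land at row global_id // part_num of the
    new shard, so the worker count may differ between save and restore. A
    custom C++ op is used when available, otherwise a numpy py_func fallback.
    """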
file_name = ops.get_default_graph().get_tensor_by_name(
self.saver_def.filename_tensor_name)
    # plain python ints so they can be fed to py_func and used as op attrs
    embed_dim = int(embed_var.get_shape()[-1])
    embed_part_size = int(embed_var.get_shape()[0])

    def _load_embed(embed, embed_dim, embed_part_size, part_id, part_num,
filename, var_name):
filename = filename.decode('utf-8')
var_name = var_name.decode('utf-8').replace('/', '__')
embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
embed_files = gfile.Glob(embed_pattern)
embed_files.sort(key=_get_embed_part_id)
logging.info('task[%d] embed_files=%s embed_dim=%d embed_part_size=%d' %
(part_id, ','.join(embed_files), embed_dim, embed_part_size))
part_embed_vals = np.zeros([embed_part_size, embed_dim], dtype=np.float32)
part_update_cnt = 0
for embed_file in embed_files:
part_id_o = _get_embed_part_id(embed_file)
with gfile.GFile(embed_file, 'rb') as fin:
embed_val = np.frombuffer(fin.read(), np.float32)
embed_val = embed_val.reshape([-1, embed_dim])
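        # reconstruct global row ids: row i of old part p held global id
        # p + i * old_part_num under round-robin partitioning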
embed_ids_o = np.arange(len(embed_val))
embed_ids_o = part_id_o + embed_ids_o * len(embed_files)
sel_ids = np.where(
np.logical_and((embed_ids_o % part_num) == part_id,
embed_ids_o < embed_part_size * part_num))[0]
part_update_cnt += len(sel_ids)
        embed_ids = embed_ids_o[sel_ids]
        # integer floor division avoids the float round trip
        embed_ids_n = embed_ids // part_num
        part_embed_vals[embed_ids_n] = embed_val[sel_ids]
logging.info('task[%d] load_part_cnt=%d' % (part_id, part_update_cnt))
return part_embed_vals

    with ops.control_dependencies([embed_var._initializer_op]):
if load_embed_lib is not None:
embed_val = load_embed_lib.load_embed(
task_index=hvd.rank(),
task_num=hvd.size(),
embed_dim=embed_dim,
embed_part_size=embed_part_size,
var_name='embed-' + embed_var.name.replace('/', '__'),
ckpt_path=file_name)
else:
embed_val = script_ops.py_func(_load_embed, [
embed_var, embed_dim, embed_part_size,
hvd.rank(),
hvd.size(), file_name, embed_var.name
], dtypes.float32)
embed_val.set_shape(embed_var.get_shape())
return state_ops.assign(embed_var, embed_val)

  def _save_kv_embedding(self, sok_var):
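    """Build ops that dump this worker's shard of a SOK key-value embedding.

    Keys and values are exported from the dynamic variable and written to
    '<ckpt>-embedding/embed-<var_name>-part-<rank>.key' and the matching
    '.val' file.
    """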
indices, values = dynamic_variable_ops.dummy_var_export(
sok_var.handle, key_type=sok_var.key_type, dtype=sok_var.handle_dtype)
file_name = ops.get_default_graph().get_tensor_by_name(
self.saver_def.filename_tensor_name)

    def _save_key_vals(indices, values, filename, var_name):
var_name = var_name.decode('utf-8').replace('/', '__')
filename = filename.decode('utf-8')
sok_dir = filename + '-embedding/'
if not gfile.Exists(sok_dir):
gfile.MakeDirs(sok_dir)
task_id = hvd.rank()
key_file = filename + '-embedding/embed-' + var_name + '-part-%d.key' % task_id
with gfile.GFile(key_file, 'wb') as fout:
fout.write(indices.tobytes())
val_file = filename + '-embedding/embed-' + var_name + '-part-%d.val' % task_id
with gfile.GFile(val_file, 'wb') as fout:
fout.write(values.tobytes())
      if task_id == 0:
        # remove stale shards left over from a previous run with more workers
        key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
        key_files = gfile.Glob(key_file_pattern)
        for old_key_file in key_files:
          embed_id = _get_embed_part_id(old_key_file)
          if embed_id >= hvd.size():
            gfile.Remove(old_key_file)
            old_val_file = old_key_file[:-4] + '.val'
            if gfile.Exists(old_val_file):
              gfile.Remove(old_val_file)
      return np.asarray([key_file, val_file], order='C', dtype=object)

    save_paths = script_ops.py_func(_save_key_vals,
[indices, values, file_name, sok_var.name],
dtypes.string)
return save_paths

  def _load_kv_embedding(self, sok_var):
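    """Build an op restoring this worker's shard of a SOK key-value embedding.

    Saved keys are re-partitioned by key % worker_num == rank, so the worker
    count may change between save and restore. A custom C++ op is used when
    available, otherwise a numpy py_func fallback.
    """
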
def _load_key_vals(filename, var_name):
var_name = var_name.decode('utf-8').replace('/', '__')
filename = filename.decode('utf-8')
key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
logging.info('key_file_pattern=%s filename=%s var_name=%s var=%s' %
(key_file_pattern, filename, var_name, str(sok_var)))
key_files = gfile.Glob(key_file_pattern)
logging.info('key_file_pattern=%s file_num=%d' %
(key_file_pattern, len(key_files)))
all_keys = []
all_vals = []
      for key_file in key_files:
        with gfile.GFile(key_file, 'rb') as fin:
          tmp_keys = np.frombuffer(fin.read(), dtype=np.int64)
        tmp_ids = tmp_keys % hvd.size()
        tmp_ids = np.where(tmp_ids == hvd.rank())[0]
        if len(tmp_ids) == 0:
          # this part holds no keys for this rank, check the remaining parts
          continue
        all_keys.append(tmp_keys.take(tmp_ids, axis=0))
        logging.info('part_keys.shape=%s %s %s' % (str(
            tmp_keys.shape), str(tmp_ids.shape), str(all_keys[-1].shape)))
        val_file = key_file[:-4] + '.val'
        with gfile.GFile(val_file, 'rb') as fin:
          tmp_vals = np.frombuffer(
              fin.read(), dtype=np.float32).reshape([-1, sok_var._dimension])
        all_vals.append(tmp_vals.take(tmp_ids, axis=0))
        logging.info('part_vals.shape=%s %s %s' % (str(
            tmp_vals.shape), str(tmp_ids.shape), str(all_vals[-1].shape)))
      if len(all_keys) == 0:
        # nothing restored on this rank, return empty key/value arrays
        return (np.zeros([0], dtype=np.int64),
                np.zeros([0, sok_var._dimension], dtype=np.float32))
      all_keys = np.concatenate(all_keys, axis=0)
      all_vals = np.concatenate(all_vals, axis=0)
      shuffle_ids = np.arange(len(all_keys))
      np.random.shuffle(shuffle_ids)
      all_keys = all_keys.take(shuffle_ids, axis=0)
      all_vals = all_vals.take(shuffle_ids, axis=0)
return all_keys, all_vals

    file_name = ops.get_default_graph().get_tensor_by_name(
self.saver_def.filename_tensor_name)
if load_embed_lib is not None:
keys, vals = load_embed_lib.load_kv_embed(
task_index=hvd.rank(),
task_num=hvd.size(),
embed_dim=sok_var._dimension,
var_name='embed-' + sok_var.name.replace('/', '__'),
ckpt_path=file_name)
else:
logging.warning('libload_embed.so not loaded, will use python script_ops')
keys, vals = script_ops.py_func(_load_key_vals, [file_name, sok_var.name],
(dtypes.int64, dtypes.float32))
with ops.control_dependencies([sok_var._initializer_op]):
return dynamic_variable_ops.dummy_var_assign(sok_var.handle, keys, vals)

  def build(self):
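    """Build the saver and splice embedding save/restore ops into saver_def.

    Per-shard save/restore ops run on every worker, while the original save
    op for non-embedding variables is kept only on worker 0. The grouped ops
    replace restore_op_name and save_tensor_name in saver_def, so checkpoint
    save and restore transparently cover the embedding shards.
    """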
if self._is_built:
return
super(EmbeddingParallelSaver, self).build()
if self.saver_def.restore_op_name and self._has_embed_vars():
# load data from the model
restore_ops = []
for sok_var in self._kv_vars:
restore_ops.append(self._load_kv_embedding(sok_var))
for embed_var in self._embed_vars:
restore_ops.append(self._load_dense_embedding(embed_var))
old_restore_op = ops.get_default_graph().get_operation_by_name(
self.saver_def.restore_op_name)
restore_ops.append(old_restore_op)
restore_op_n = control_flow_ops.group(restore_ops)
self.saver_def.restore_op_name = restore_op_n.name
if self.saver_def.save_tensor_name and self._has_embed_vars():
file_name = ops.get_default_graph().get_tensor_by_name(
self.saver_def.filename_tensor_name)
save_part_ops = []
for sok_var in self._kv_vars:
save_part_op = self._save_kv_embedding(sok_var)
save_part_ops.append(save_part_op)
for embed_var in self._embed_vars:
save_part_op = self._save_dense_embedding(embed_var)
save_part_ops.append(save_part_op)
old_save_op = ops.get_default_graph().get_tensor_by_name(
self.saver_def.save_tensor_name)
      # only the first worker needs to save the non-embedding variables
if hvd.rank() == 0:
save_part_ops.append(old_save_op)
with ops.control_dependencies(save_part_ops):
save_op_n = array_ops.identity(file_name)
self.saver_def.save_tensor_name = save_op_n.name
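
# A minimal usage sketch (an illustration, not part of this module): it
# assumes hvd.init() has been called, the parallel embedding variables are
# registered in the constant.EmbeddingParallel collection, and that 'sess'
# and 'train_dir' already exist. Every worker calls save()/restore() so that
# each one writes and reads its own shard files.
#
#   saver = EmbeddingParallelSaver(var_list=tf.global_variables())
#   saver.save(sess, os.path.join(train_dir, 'model.ckpt'), global_step=1000)
#   saver.restore(sess, os.path.join(train_dir, 'model.ckpt-1000'))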