in easy_rec/python/utils/estimator_utils.py [0:0]
def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename='model.ckpt',
scaffold=None,
listeners=None,
write_graph=True,
data_offset_var=None,
increment_save_config=None):
"""Initializes a `CheckpointSaverHook`.
Args:
checkpoint_dir: `str`, base directory for the checkpoint files.
save_secs: `int`, save every N secs.
save_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object.
listeners: List of `CheckpointSaverListener` subclass instances.
Used for callbacks that run immediately before or after this hook saves
the checkpoint.
write_graph: whether to save graph.pbtxt.
data_offset_var: data offset variable.
increment_save_config: parameters for saving increment checkpoints.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of saver or scaffold should be set.
"""
super(CheckpointSaverHook, self).__init__(
checkpoint_dir,
save_secs=save_secs,
save_steps=save_steps,
saver=saver,
checkpoint_basename=checkpoint_basename,
scaffold=scaffold,
listeners=listeners)
self._cuda_profile_start = 0
self._cuda_profile_stop = 0
self._steps_per_run = 1
self._write_graph = write_graph
self._data_offset_var = data_offset_var
self._task_idx, self._task_num = get_task_index_and_num()
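    # When increment_save_config is set, prepare everything needed to export
    # incremental (dense / sparse) updates in addition to the full checkpoints.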
if increment_save_config is not None:
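      # Kafka client limits, overridable via environment variables.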
      self._kafka_timeout_ms = int(os.environ.get('KAFKA_TIMEOUT', 600)) * 1000
      logging.info('KAFKA_TIMEOUT: %dms' % self._kafka_timeout_ms)
      self._kafka_max_req_size = int(
          os.environ.get('KAFKA_MAX_REQ_SIZE', 1024 * 1024 * 64))
      logging.info('KAFKA_MAX_REQ_SIZE: %d' % self._kafka_max_req_size)
      self._kafka_max_msg_size = int(
          os.environ.get('KAFKA_MAX_MSG_SIZE', 1024 * 1024 * 1024))
      logging.info('KAFKA_MAX_MSG_SIZE: %d' % self._kafka_max_msg_size)
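      # Map dense / sparse variable names to stable ids and persist the dense
      # mapping alongside the checkpoints.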
self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()
with gfile.GFile(
os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
'w') as fout:
json.dump(self._dense_name_to_ids, fout, indent=2)
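      # Separate timers control how often dense and sparse increments are
      # exported; a non-positive interval disables that trigger.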
save_secs = increment_save_config.dense_save_secs
save_steps = increment_save_config.dense_save_steps
self._dense_timer = SecondOrStepTimer(
every_secs=save_secs if save_secs > 0 else None,
every_steps=save_steps if save_steps > 0 else None)
save_secs = increment_save_config.sparse_save_secs
save_steps = increment_save_config.sparse_save_steps
self._sparse_timer = SecondOrStepTimer(
every_secs=save_secs if save_secs > 0 else None,
every_steps=save_steps if save_steps > 0 else None)
self._dense_timer.update_last_triggered_step(0)
self._sparse_timer.update_last_triggered_step(0)
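      # For each variable in the SPARSE_UPDATE_VARIABLES collection, build ops
      # that fetch the sparse indices to be exported and the corresponding rows
      # (kv_resource_incr_gather for EmbeddingVariable, plain gather otherwise).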
self._sparse_indices = []
self._sparse_values = []
sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
for sparse_var, indice_dtype in sparse_train_vars:
with ops.control_dependencies([tf.train.get_global_step()]):
with ops.colocate_with(sparse_var):
sparse_indice = get_sparse_indices(
var_name=sparse_var.op.name, ktype=indice_dtype)
# sparse_indice = sparse_indice.global_indices
self._sparse_indices.append(sparse_indice)
if 'EmbeddingVariable' in str(type(sparse_var)):
self._sparse_values.append(
kv_resource_incr_gather(
sparse_var._handle, sparse_indice,
np.zeros(sparse_var.shape.as_list(), dtype=np.float32)))
# sparse_var.sparse_read(sparse_indice))
else:
self._sparse_values.append(
array_ops.gather(sparse_var, sparse_indice))
self._kafka_producer = None
self._incr_save_dir = None
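      # Incremental updates are sent to a Kafka topic or written to a
      # filesystem directory, depending on the configured target.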
if increment_save_config.HasField('kafka'):
self._topic = increment_save_config.kafka.topic
logging.info('increment save topic: %s' % self._topic)
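        # Create the topic (single partition) if it does not exist yet.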
admin_clt = KafkaAdminClient(
bootstrap_servers=increment_save_config.kafka.server,
request_timeout_ms=self._kafka_timeout_ms,
api_version_auto_timeout_ms=self._kafka_timeout_ms)
if self._topic not in admin_clt.list_topics():
admin_clt.create_topics(
new_topics=[
NewTopic(
name=self._topic,
num_partitions=1,
replication_factor=1,
topic_configs={
'max.message.bytes': self._kafka_max_msg_size
})
],
validate_only=False)
logging.info('create increment save topic: %s' % self._topic)
admin_clt.close()
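        # Producer used to publish the incremental updates to the topic.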
servers = increment_save_config.kafka.server.split(',')
self._kafka_producer = KafkaProducer(
bootstrap_servers=servers,
max_request_size=self._kafka_max_req_size,
api_version_auto_timeout_ms=self._kafka_timeout_ms,
request_timeout_ms=self._kafka_timeout_ms)
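      # Filesystem target: resolve the (optionally checkpoint-relative) output
      # directory and make sure it exists.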
elif increment_save_config.HasField('fs'):
fs = increment_save_config.fs
if fs.relative:
self._incr_save_dir = os.path.join(checkpoint_dir, fs.incr_save_dir)
else:
self._incr_save_dir = fs.incr_save_dir
if not self._incr_save_dir.endswith('/'):
self._incr_save_dir += '/'
if not gfile.IsDirectory(self._incr_save_dir):
gfile.MakeDirs(self._incr_save_dir)
elif increment_save_config.HasField('datahub'):
raise NotImplementedError('datahub increment saving is in development.')
else:
raise ValueError(
'incr_update not specified correctly, must be oneof: kafka,fs')
self._debug_save_update = increment_save_config.debug_save_update
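    # No increment_save_config: incremental-update timers stay disabled.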
else:
self._dense_timer = None
self._sparse_timer = None