def __init__()

in easy_rec/python/utils/estimator_utils.py


  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename='model.ckpt',
               scaffold=None,
               listeners=None,
               write_graph=True,
               data_offset_var=None,
               increment_save_config=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.
      write_graph: whether to save graph.pbtxt.
      data_offset_var: variable recording input data offsets, saved together
        with the checkpoint.
      increment_save_config: config for incremental checkpoint saving; controls
        how often dense and sparse updates are exported and whether they are
        sent to Kafka or written to the filesystem.

    Raises:
      ValueError: If not exactly one of `save_secs` or `save_steps` is set.
      ValueError: If both `saver` and `scaffold` are set.
    """
    super(CheckpointSaverHook, self).__init__(
        checkpoint_dir,
        save_secs=save_secs,
        save_steps=save_steps,
        saver=saver,
        checkpoint_basename=checkpoint_basename,
        scaffold=scaffold,
        listeners=listeners)
    self._cuda_profile_start = 0
    self._cuda_profile_stop = 0
    self._steps_per_run = 1
    self._write_graph = write_graph
    self._data_offset_var = data_offset_var

    self._task_idx, self._task_num = get_task_index_and_num()

    if increment_save_config is not None:
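      # values read from os.environ are strings, so cast to int before use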
      self._kafka_timeout_ms = int(os.environ.get('KAFKA_TIMEOUT', 600)) * 1000
      logging.info('KAFKA_TIMEOUT: %dms' % self._kafka_timeout_ms)
      self._kafka_max_req_size = int(
          os.environ.get('KAFKA_MAX_REQ_SIZE', 1024 * 1024 * 64))
      logging.info('KAFKA_MAX_REQ_SIZE: %d' % self._kafka_max_req_size)
      self._kafka_max_msg_size = int(
          os.environ.get('KAFKA_MAX_MSG_SIZE', 1024 * 1024 * 1024))
      logging.info('KAFKA_MAX_MSG_SIZE: %d' % self._kafka_max_msg_size)

      self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
      self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()

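      # persist the dense variable name -> id mapping alongside the checkpoint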
      with gfile.GFile(
          os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
          'w') as fout:
        json.dump(self._dense_name_to_ids, fout, indent=2)

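      # separate timers control how often dense and sparse incremental updates
      # are exported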
      save_secs = increment_save_config.dense_save_secs
      save_steps = increment_save_config.dense_save_steps
      self._dense_timer = SecondOrStepTimer(
          every_secs=save_secs if save_secs > 0 else None,
          every_steps=save_steps if save_steps > 0 else None)
      save_secs = increment_save_config.sparse_save_secs
      save_steps = increment_save_config.sparse_save_steps
      self._sparse_timer = SecondOrStepTimer(
          every_secs=save_secs if save_secs > 0 else None,
          every_steps=save_steps if save_steps > 0 else None)

      self._dense_timer.update_last_triggered_step(0)
      self._sparse_timer.update_last_triggered_step(0)

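      # build gather ops that read the sparse ids and the corresponding
      # embedding rows for each variable in the SPARSE_UPDATE_VARIABLES
      # collection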
      self._sparse_indices = []
      self._sparse_values = []
      sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
      for sparse_var, indice_dtype in sparse_train_vars:
        with ops.control_dependencies([tf.train.get_global_step()]):
          with ops.colocate_with(sparse_var):
            sparse_indice = get_sparse_indices(
                var_name=sparse_var.op.name, ktype=indice_dtype)
          # sparse_indice = sparse_indice.global_indices
        self._sparse_indices.append(sparse_indice)
        if 'EmbeddingVariable' in str(type(sparse_var)):
          self._sparse_values.append(
              kv_resource_incr_gather(
                  sparse_var._handle, sparse_indice,
                  np.zeros(sparse_var.shape.as_list(), dtype=np.float32)))
          # sparse_var.sparse_read(sparse_indice))
        else:
          self._sparse_values.append(
              array_ops.gather(sparse_var, sparse_indice))

      self._kafka_producer = None
      self._incr_save_dir = None
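      # incremental updates are either published to a Kafka topic or written
      # to a directory on the filesystem (datahub support is not implemented)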
      if increment_save_config.HasField('kafka'):
        self._topic = increment_save_config.kafka.topic
        logging.info('increment save topic: %s' % self._topic)

        admin_clt = KafkaAdminClient(
            bootstrap_servers=increment_save_config.kafka.server,
            request_timeout_ms=self._kafka_timeout_ms,
            api_version_auto_timeout_ms=self._kafka_timeout_ms)
        if self._topic not in admin_clt.list_topics():
          admin_clt.create_topics(
              new_topics=[
                  NewTopic(
                      name=self._topic,
                      num_partitions=1,
                      replication_factor=1,
                      topic_configs={
                          'max.message.bytes': self._kafka_max_msg_size
                      })
              ],
              validate_only=False)
        logging.info('create increment save topic: %s' % self._topic)
        admin_clt.close()

        servers = increment_save_config.kafka.server.split(',')
        self._kafka_producer = KafkaProducer(
            bootstrap_servers=servers,
            max_request_size=self._kafka_max_req_size,
            api_version_auto_timeout_ms=self._kafka_timeout_ms,
            request_timeout_ms=self._kafka_timeout_ms)
      elif increment_save_config.HasField('fs'):
        fs = increment_save_config.fs
        if fs.relative:
          self._incr_save_dir = os.path.join(checkpoint_dir, fs.incr_save_dir)
        else:
          self._incr_save_dir = fs.incr_save_dir
        if not self._incr_save_dir.endswith('/'):
          self._incr_save_dir += '/'
        if not gfile.IsDirectory(self._incr_save_dir):
          gfile.MakeDirs(self._incr_save_dir)
      elif increment_save_config.HasField('datahub'):
        raise NotImplementedError('datahub increment saving is in development.')
      else:
        raise ValueError(
            'incr_update not specified correctly, must be oneof: kafka,fs')

      self._debug_save_update = increment_save_config.debug_save_update
    else:
      self._dense_timer = None
      self._sparse_timer = None
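
A minimal usage sketch (not from the source): the hook is constructed like the stock CheckpointSaverHook, plus the easy_rec specific arguments documented above. `incr_cfg` stands in for the parsed increment-save config protobuf (its field names such as `dense_save_secs` and `kafka.topic` are referenced in the code, but the exact message type is an assumption), and `offset_var` for the data offset variable.

# hypothetical wiring inside estimator training setup
saver_hook = CheckpointSaverHook(
    checkpoint_dir='/tmp/easy_rec_ckpt',
    save_steps=1000,  # exactly one of save_steps / save_secs must be set
    write_graph=True,
    data_offset_var=offset_var,
    increment_save_config=incr_cfg)
estimator.train(input_fn, hooks=[saver_hook])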