def _send_sparse()

in easy_rec/python/utils/estimator_utils.py [0:0]


  def _send_sparse(self, global_step, session):
    """Fetch pending sparse updates and publish them as one binary message.

    Evaluates ``self._sparse_indices + self._sparse_values`` in ``session``,
    keeps only the variables whose fetched index array is non-empty, and
    serializes them as:

      header (int32): [1, msg_num, global_step,
                       embed_id_0, num_keys_0, embed_id_1, num_keys_1, ...]
      body: for each selected variable, its key bytes followed by its value
            bytes, each in the fetched array's own dtype.

    The message is sent to Kafka when ``self._kafka_producer`` is set. It is
    written to ``self._incr_save_dir`` together with a ``.done`` flag file when
    that directory is set. When ``self._debug_save_update`` is true and no
    incremental save dir is configured, it is written under ``incr_save/`` in
    the directory containing ``self._save_path``. If no variable has updates,
    a warning is logged and nothing is sent or written.

    Args:
      global_step: training step, used in the header, message key and file
        names.
      session: session used to evaluate the sparse index/value tensors.
    """
    sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
    sparse_res = session.run(self._sparse_indices + self._sparse_values)
    # The fetch list is all index tensors followed by all value tensors.
    total_num = len(sparse_res) // 2

    # Keep only the variables that actually have updates this step.
    sel_ids = [i for i in range(total_num) if len(sparse_res[i]) > 0]
    sparse_key_res = [sparse_res[i] for i in sel_ids]
    sparse_val_res = [sparse_res[i + total_num] for i in sel_ids]
    # Collection entries are indexed with [0]; presumably tuples whose first
    # element is the variable -- confirm against where the collection is filled.
    sparse_train_vars = [sparse_train_vars[i][0] for i in sel_ids]

    sel_embed_ids = [
        self._sparse_name_to_ids[x.name] for x in sparse_train_vars
    ]

    msg_num = len(sel_ids)

    if msg_num == 0:
      logging.warning('there are no sparse updates, will skip this send: %d',
                      global_step)
      return

    # build msg header
    # 1 means sparse update messages
    msg_header = [1, msg_num, global_step]
    for tmp_id, tmp_key in zip(sel_embed_ids, sparse_key_res):
      msg_header.extend((tmp_id, len(tmp_key)))

    # Collect the pieces and join once at the end. Appending to a bytes
    # object with `+=` copies the whole buffer each time, which is quadratic
    # in the message size.
    buf_parts = [np.array(msg_header, dtype=np.int32).tobytes()]

    # build msg body
    for tmp_key, tmp_val, tmp_var in zip(sparse_key_res, sparse_val_res,
                                         sparse_train_vars):
      # for non kv embedding variables, add partition offset to tmp_key
      if 'EmbeddingVariable' not in str(type(tmp_var)):
        slice_info = tmp_var._save_slice_info
        if slice_info is not None:
          # Add on a copy, not in place, so the fetched array is left intact.
          # The offset is cast to the key dtype so the result keeps the same
          # dtype, and therefore the same serialized width, as the keys.
          tmp_key = tmp_key + tmp_key.dtype.type(slice_info.var_offset[0])
      buf_parts.append(tmp_key.tobytes())
      buf_parts.append(tmp_val.tobytes())
    bytes_buf = b''.join(buf_parts)

    if self._kafka_producer is not None:
      msg_key = 'sparse_update_%d' % global_step
      send_res = self._kafka_producer.send(
          self._topic, bytes_buf, key=msg_key.encode('utf-8'))
      # NOTE(review): send() is presumably asynchronous, so `exception` may not
      # be populated yet when this is logged -- confirm with the producer API.
      logging.info('kafka send sparse: %d %s', global_step, send_res.exception)

    if self._incr_save_dir is not None:
      save_path = os.path.join(self._incr_save_dir,
                               'sparse_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)
      # The flag file is written only after the data file has been written.
      save_flag = save_path + '.done'
      with gfile.GFile(save_flag, 'w') as fout:
        fout.write('sparse_update_%d' % global_step)

    if self._debug_save_update and self._incr_save_dir is None:
      base_dir, _ = os.path.split(self._save_path)
      incr_save_dir = os.path.join(base_dir, 'incr_save/')
      if not gfile.Exists(incr_save_dir):
        gfile.MakeDirs(incr_save_dir)
      save_path = os.path.join(incr_save_dir, 'sparse_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)

    logging.info(
        'global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d',
        global_step, msg_num, len(bytes_buf))