in easy_rec/python/utils/estimator_utils.py [0:0]
def _send_sparse(self, global_step, session):
  """Serialize the latest sparse (embedding) updates and publish them.

  Fetches `self._sparse_indices + self._sparse_values` with `session.run`,
  keeps only variables that produced at least one updated index, packs them
  into a single binary message and then, depending on configuration:
    * sends it to Kafka (`self._kafka_producer`, topic `self._topic`);
    * writes it to `self._incr_save_dir/sparse_update_<step>` plus a
      `.done` flag file;
    * if `self._debug_save_update` is set and no incremental save dir is
      configured, writes it to `<dirname(self._save_path)>/incr_save/`.

  Message layout:
    header (int32): [1, msg_num, global_step,
                     embed_id_0, num_keys_0, ..., embed_id_k, num_keys_k]
    body: for each selected variable, its key bytes followed by its value
      bytes (raw `ndarray.tobytes()` in the fetched dtypes).

  Args:
    global_step: current training step; written into the header and used
      to name the Kafka message key and the saved files.
    session: session used to fetch the sparse index/value tensors.
  """
  sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
  sparse_res = session.run(self._sparse_indices + self._sparse_values)
  # The fetch list is all index tensors followed by all value tensors.
  msg_num = len(sparse_res) // 2
  # Skip variables that received no updates in this interval.
  sel_ids = [i for i in range(msg_num) if len(sparse_res[i]) > 0]
  sparse_key_res = [sparse_res[i] for i in sel_ids]
  sparse_val_res = [sparse_res[i + msg_num] for i in sel_ids]
  # Collection entries are indexed with [0] to get the variable object;
  # presumably each entry is a (variable, ...) tuple -- confirm at producer.
  sparse_train_vars = [sparse_train_vars[i][0] for i in sel_ids]
  sel_embed_ids = [
      self._sparse_name_to_ids[x.name] for x in sparse_train_vars
  ]
  msg_num = len(sel_ids)
  if msg_num == 0:
    logging.warning('there are no sparse updates, will skip this send: %d',
                    global_step)
    return

  # build msg header
  # 1 means sparse update messages
  msg_header = [1, msg_num, global_step]
  for tmp_id, tmp_key in zip(sel_embed_ids, sparse_key_res):
    msg_header.extend((tmp_id, len(tmp_key)))

  # build msg body: collect the pieces and join once at the end, because
  # repeated `bytes += ...` copies the whole buffer on every append
  # (quadratic in the total message size).
  buf_parts = [np.array(msg_header, dtype=np.int32).tobytes()]
  for tmp_key, tmp_val, tmp_var in zip(sparse_key_res, sparse_val_res,
                                       sparse_train_vars):
    # for non kv embedding variables, add partition offset to tmp_key;
    # the in-place add keeps the key dtype (and thus the wire width) intact.
    if 'EmbeddingVariable' not in str(type(tmp_var)):
      if tmp_var._save_slice_info is not None:
        tmp_key += tmp_var._save_slice_info.var_offset[0]
    buf_parts.append(tmp_key.tobytes())
    buf_parts.append(tmp_val.tobytes())
  bytes_buf = b''.join(buf_parts)

  if self._kafka_producer is not None:
    msg_key = 'sparse_update_%d' % global_step
    send_res = self._kafka_producer.send(
        self._topic, bytes_buf, key=msg_key.encode('utf-8'))
    logging.info('kafka send sparse: %d %s', global_step,
                 send_res.exception)

  if self._incr_save_dir is not None:
    save_path = os.path.join(self._incr_save_dir,
                             'sparse_update_%d' % global_step)
    with gfile.GFile(save_path, 'wb') as fout:
      fout.write(bytes_buf)
    # The flag file is written only after the data file is fully written.
    save_flag = save_path + '.done'
    with gfile.GFile(save_flag, 'w') as fout:
      fout.write('sparse_update_%d' % global_step)

  if self._debug_save_update and self._incr_save_dir is None:
    base_dir, _ = os.path.split(self._save_path)
    incr_save_dir = os.path.join(base_dir, 'incr_save/')
    if not gfile.Exists(incr_save_dir):
      gfile.MakeDirs(incr_save_dir)
    save_path = os.path.join(incr_save_dir, 'sparse_update_%d' % global_step)
    with gfile.GFile(save_path, 'wb') as fout:
      fout.write(bytes_buf)

  logging.info(
      'global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d',
      global_step, msg_num, len(bytes_buf))