in coinrun/tb_utils.py [0:0]
def __init__(self, sess):
    """Set up per-rank TensorBoard logging and expose two logging callables.

    Calls ``clean_tb_dir()`` (defined elsewhere), then opens a
    ``tf.summary.FileWriter`` in ``Config.TB_DIR + '/' + <RUN_ID>_<rank>``
    attached to ``sess.graph``. Only MPI rank 0 writes events unless
    ``Config.LOG_ALL_MPI`` is set. On a logging rank, the text returned by
    ``Config.get_arg_text()`` is written once as a "hyperparameters info"
    text summary.

    After construction the instance has two callable attributes:

    * ``self.add_summary(_merged, interval=1)`` -- write an already
      evaluated summary on every ``interval``-th call.
    * ``self.log_scalar(x, name, step=-1)`` -- write a single scalar value,
      building its summary graph lazily the first time ``name`` is seen.

    Args:
        sess: session used to run summary ops; its ``graph`` is passed to
            the writer. Presumably a TF1 ``tf.Session`` -- confirm.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    clean_tb_dir()

    tb_writer = tf.summary.FileWriter(Config.TB_DIR + '/' + Config.RUN_ID + '_' + str(rank), sess.graph)
    # NOTE(review): the writer is not stored on self and is never closed in
    # this block; events reach disk only through the flush() calls below.

    # Step counter shared by add_summary and log_scalar. It is a one-element
    # list so the nested closures can mutate it.
    total_steps = [0]

    should_log = (rank == 0 or Config.LOG_ALL_MPI)

    if should_log:
        hyperparams = np.array(Config.get_arg_text())
        hyperparams_tensor = tf.constant(hyperparams)

        summary_op = tf.summary.text("hyperparameters info", hyperparams_tensor)
        summary = sess.run(summary_op)

        tb_writer.add_summary(summary)

    def add_summary(_merged, interval=1):
        """Write ``_merged`` on every ``interval``-th call (logging ranks only).

        Each call on a logging rank advances the shared step counter. The
        summary is written, and the writer flushed, only when the new count
        is divisible by ``interval``; that count is used as the step.
        Calls on non-logging ranks do nothing.
        """
        if should_log:
            total_steps[0] += 1

            if total_steps[0] % interval == 0:
                tb_writer.add_summary(_merged, total_steps[0])
                tb_writer.flush()

    # One (placeholder, merged_summary) pair per distinct scalar name, in
    # registration order; name_dict maps each user-facing name to its index.
    tuples = []
    name_dict = {}

    def make_scalar_graph(name):
        """Build a float32 placeholder + scalar summary and append the pair to ``tuples``."""
        scalar_ph = tf.placeholder(name='scalar_' + name, dtype=tf.float32)
        scalar_summary = tf.summary.scalar(name, scalar_ph)
        merged = tf.summary.merge([scalar_summary])
        tuples.append((scalar_ph, merged))

    def log_scalar(x, name, step=-1):
        """Write scalar ``x`` under ``name``.

        The graph ops for ``name`` are built on first use, on every rank, but
        only logging ranks evaluate and write them.

        Args:
            x: value fed to a ``tf.float32`` placeholder.
            name: scalar name. The first name ever registered gets
                ``'_' + Config.RUN_ID`` appended to its summary tag.
            step: step to record. The default ``-1`` means "use the shared
                step counter, then advance it". An explicit step leaves the
                counter unchanged.
        """
        if name not in name_dict:
            idx = len(tuples)
            # The very first scalar's tag carries the run id, presumably so
            # the run is identifiable in TensorBoard -- confirm intent.
            tf_name = (name + '_' + Config.RUN_ID) if idx == 0 else name
            # Build the ops BEFORE registering the name. If op creation
            # raises (e.g. TF rejects the derived op name), no dangling
            # index is left behind for a later name to silently reuse.
            make_scalar_graph(tf_name)
            name_dict[name] = idx

        scalar_ph, merged = tuples[name_dict[name]]

        if should_log:
            if step == -1:
                step = total_steps[0]
                total_steps[0] += 1

            _merged = sess.run(merged, {scalar_ph: x})

            tb_writer.add_summary(_merged, step)
            tb_writer.flush()

    self.add_summary = add_summary
    self.log_scalar = log_scalar