# utils.py
import itertools
import json
import os
import subprocess
import tempfile
import time

import numpy as np
import tensorflow as tf
import blocksparse as bs

from mpi_utils import mpi_rank


def logger(log_prefix):
    'Prints the arguments out to stdout, .txt, and .jsonl files'
    jsonl_path = f'{log_prefix}.jsonl'
    txt_path = f'{log_prefix}.txt'

    def log(*args, pprint=False, **kwargs):
        if mpi_rank() != 0:
            return
        t = time.ctime()
        argdict = {'time': t}
        if len(args) > 0:
            argdict['message'] = ' '.join([str(x) for x in args])
        argdict.update(kwargs)
        txt_str = []
        args_iter = sorted(argdict) if pprint else argdict
        for k in args_iter:
            val = argdict[k]
            if isinstance(val, np.ndarray):
                val = val.tolist()
            elif isinstance(val, np.integer):
                val = int(val)
            elif isinstance(val, np.floating):
                val = float(val)
            argdict[k] = val
            if isinstance(val, float):
                if k == 'lr':
                    val = f'{val:.6f}'
                else:
                    val = f'{val:.4f}'
            txt_str.append(f'{k}: {val}')
        txt_str = ', '.join(txt_str)
        if pprint:
            json_str = json.dumps(argdict, sort_keys=True)
            txt_str = json.dumps(argdict, sort_keys=True, indent=4)
        else:
            json_str = json.dumps(argdict)
        print(txt_str, flush=True)
        with open(txt_path, "a+") as f:
            print(txt_str, file=f, flush=True)
        with open(jsonl_path, "a+") as f:
            print(json_str, file=f, flush=True)
    return log
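
# A minimal usage sketch for logger (the paths and values below are
# hypothetical, not from the codebase):
#
#   log = logger('/tmp/run0/train')
#   log('starting epoch', epoch=0, lr=3e-4)   # one compact line per record
#   log(eval_loss=2.31, pprint=True)          # sorted, indented JSON to stdout
#
# Each call appends one line to train.txt and one JSON object to train.jsonl,
# and only MPI rank 0 writes anything.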


def go_over(choices):
    return itertools.product(*[range(n) for n in choices])
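
# For example, go_over([2, 3]) yields the index tuples
# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2): every combination of
# options when choices holds the number of alternatives per axis.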


def get_git_revision():
    git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
    return git_hash.strip().decode('utf-8')


def shape_list(x):
    """
    Deal with dynamic shapes in TensorFlow cleanly: return static sizes where
    they are known and scalar tensors where they are not.
    """
    ps = x.get_shape().as_list()
    ts = tf.shape(x)
    return [ts[i] if ps[i] is None else ps[i] for i in range(len(ps))]
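
# For example, with a placeholder whose batch dimension is unknown,
#   x = tf.placeholder(tf.float32, [None, 128])
# shape_list(x) returns [tf.shape(x)[0], 128]: a scalar tensor for the dynamic
# axis and a plain Python int for the static one, so the result can be passed
# directly to ops such as tf.reshape.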


def rsync_data(from_path, to_path):
    subprocess.check_output(['rsync', '-r', from_path, to_path, '--update'])


def maybe_download(path):
    '''If a path is a gsutil path, download it and return the local path,
    otherwise return the path unchanged.'''
    if not path.startswith('gs://'):
        return path
    # mkstemp returns an open file descriptor along with the path; close the
    # descriptor so it is not leaked.
    fd, local_dest = tempfile.mkstemp()
    os.close(fd)
    subprocess.check_output(['gsutil', '-m', 'cp', path, local_dest])
    return local_dest
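
# Usage sketch (hypothetical bucket path); anything not starting with 'gs://'
# passes through untouched:
#
#   local = maybe_download('gs://my-bucket/data/train.npz')  # copied to a temp file
#   local = maybe_download('/data/train.npz')                # returned as-is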


def upload_to_gcp(from_path, to_path, is_async=False):
    if is_async:
        # The trailing '&' backgrounds the sync so the caller does not block.
        cmd = f'bash -ex -c "gsutil -m rsync -r {from_path} {to_path}" &'
        subprocess.call(cmd, shell=True, stderr=subprocess.DEVNULL)
    else:
        subprocess.check_output(['gsutil', '-m', 'rsync', from_path, to_path])


def check_identical(from_path, to_path):
    # 'git diff --no-index' compares two paths outside of any repository and
    # exits nonzero when they differ.
    try:
        subprocess.check_output(['git', 'diff', '--no-index', '--quiet',
                                 from_path, to_path])
        return True
    except subprocess.CalledProcessError:
        return False


def wait_until_synced(from_path, to_path):
    while not check_identical(from_path, to_path):
        time.sleep(5)


def is_gcp():
    # Only GCP instances can reach the internal metadata server, so a failed
    # curl means we are running elsewhere.
    try:
        subprocess.check_output(['curl', '-s',
                                 'metadata.google.internal', '-i'])
        return True
    except subprocess.CalledProcessError:
        return False


def backup_files(save_dir, save_dir_gcp, path=None):
    if mpi_rank() == 0:
        if not path:
            print(f'Backing up {save_dir} to {save_dir_gcp}.',
                  'Will execute silently in a background process.')
            upload_to_gcp(save_dir, save_dir_gcp, is_async=True)
        else:
            upload_to_gcp(path, save_dir_gcp, is_async=True)


def log_gradient_values(grads, variables, global_step, model_dir):
    loggrads = []
    with tf.name_scope("log_gradient_values"):
        for grad, param in zip(grads, variables):
            name = param.op.name + "_" + "_".join(
                str(x) for x in param.shape.as_list())
            loggrads.append(bs.log_stats(
                grad, step=global_step, name=name,
                logfile=os.path.join(model_dir, 'grad_stats.txt')))
    return loggrads


def tf_print(t, name, summarize=10, first_n=None, mv=False, maxmin=False):
    # Useful for debugging!
    axes = list(range(len(t.shape)))
    if mv:
        m, v = tf.nn.moments(t, axes=axes)
    if maxmin:
        maxi = tf.reduce_max(t)
        mini = tf.reduce_min(t)
    prefix = f'{tf.get_variable_scope().name}-{name}'
    with tf.device('/cpu:0'):
        if mv:
            t = tf.Print(t, [tf.shape(t), m, v], prefix,
                         summarize=summarize, first_n=first_n)
        elif maxmin:
            t = tf.Print(t, [tf.shape(t), mini, maxi, t], prefix,
                         summarize=summarize, first_n=first_n)
        else:
            t = tf.Print(t, [tf.shape(t), t], prefix,
                         summarize=summarize, first_n=first_n)
    return t
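
# Debugging sketch (tensor names here are hypothetical): wrap an activation so
# its shape and per-evaluation statistics are printed each time it is run.
#
#   h = tf_print(h, 'post_attn', mv=True)        # prints shape, mean, variance
#   h = tf_print(h, 'post_attn', maxmin=True)    # prints shape, min, max, values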


def get_variables(trainable=False):
    if trainable:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    else:
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    return variables


def load_variables(sess, weights, ignore=None, trainable=False, ema=True):
    '''ema controls whether the exponentially-moving-averaged weights are used
    to initialize the true weights or not.'''
    weights = {os.path.normpath(key): value for key, value in weights.items()}
    ops = []
    feed_dict = {}
    if ema:
        gvs_map = {v.name: v for v in tf.global_variables()}
    for var in get_variables(trainable=trainable):
        var_name = os.path.normpath(var.name)
        if ignore and any(substr in var_name for substr in ignore):
            continue
        ph = tf.placeholder(dtype=var.dtype, shape=var.shape)
        ops.append(var.assign(ph))
        if ema:
            # Variable names end in ':0'; strip that before appending the
            # EMA suffix.
            ema_name = f'{var_name[:-2]}/Ema/ema:0'
            if ema_name in weights:
                # We assign the EMA value to the current value
                feed_dict[ph] = weights[ema_name]
                # We also assign the EMA value to the current EMA variable,
                # which would otherwise keep the initialized value of the
                # variable (random)
                ema_var = gvs_map[ema_name]
                ph = tf.placeholder(dtype=ema_var.dtype, shape=ema_var.shape)
                ops.append(ema_var.assign(ph))
                feed_dict[ph] = weights[ema_name]
            else:
                print(f'warning: ema var not found for {var_name}')
                feed_dict[ph] = weights[var_name]
        else:
            feed_dict[ph] = weights[var_name]
    sess.run(ops, feed_dict)


def save_params(sess, path):
    if mpi_rank() == 0:
        variables = get_variables()
        tf_vars = dict(zip([var.name for var in variables],
                           sess.run(variables)))
        np.savez(path + '.npz', **tf_vars)


def load_variables_from_file(sess, path, ignore=None, trainable=False,
                             ema=True):
    weights = dict(np.load(path))
    load_variables(sess, weights, ignore=ignore, trainable=trainable, ema=ema)
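
# Round-trip sketch for the checkpoint helpers above (paths hypothetical):
#
#   save_params(sess, '/tmp/model/step1000')            # writes step1000.npz
#   load_variables_from_file(sess, '/tmp/model/step1000.npz',
#                            ignore=['optimizer'], ema=True)
#
# With ema=True, each variable and its */Ema/ema:0 shadow are both initialized
# from the averaged weights stored in the .npz file.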