in utils/main_utils.py [0:0]
def prep_environment(args, cfg):
    """Set up the output directory, text logger and optional TensorBoard writer.

    Must be called after ``initialize_distributed_backend()`` so that
    ``args.rank`` is meaningful.

    Args:
        args: Parsed command-line namespace. ``args.rank`` and ``args.quiet``
            are read, and every attribute in ``vars(args)`` is written to the log.
        cfg: Nested config dict. ``cfg['model']['model_dir']``,
            ``cfg['model']['name']`` and ``cfg['log2tb']`` are read; the whole
            dict is written to the log.

    Returns:
        Tuple ``(logger, tb_writer, model_dir)``. ``tb_writer`` is a
        ``SummaryWriter`` on rank 0 when ``cfg['log2tb']`` is truthy, and
        ``None`` otherwise.
    """
    # Prepare loggers (must be configured after initialize_distributed_backend())
    model_dir = '{}/{}'.format(cfg['model']['model_dir'], cfg['model']['name'])
    if args.rank == 0:
        # Only rank 0 prepares the output folder.
        # NOTE(review): meaning of the second argument of prep_output_folder
        # is defined elsewhere in this module — confirm.
        prep_output_folder(model_dir, False)
    log_fn = '{}/train.log'.format(model_dir)
    logger = Logger(quiet=args.quiet, log_fn=log_fn, rank=args.rank)
    logger.add_line(str(datetime.datetime.now()))

    # Dump SLURM-related environment variables (if any) in a single pass
    # over os.environ, preserving its iteration order.
    slurm_env = [(key, value) for key, value in os.environ.items() if 'SLURM' in key]
    if slurm_env:
        logger.add_line("=" * 30 + " SLURM " + "=" * 30)
        for key, value in slurm_env:
            logger.add_line('{:30}: {}'.format(key, value))

    logger.add_line("=" * 30 + " Config " + "=" * 30)

    def print_dict(d, ident=''):
        """Log a nested dict; each nesting level adds one leading space."""
        for k in d:
            if isinstance(d[k], dict):
                logger.add_line("{}{}".format(ident, k))
                print_dict(d[k], ident=' ' + ident)
            else:
                logger.add_line("{}{}: {}".format(ident, k, str(d[k])))

    print_dict(cfg)

    logger.add_line("=" * 30 + " Args " + "=" * 30)
    for k, v in vars(args).items():
        logger.add_line('{:30} {}'.format(k, v))

    tb_writer = None
    if cfg['log2tb'] and args.rank == 0:
        # Imported lazily so TensorBoard is only required when actually used.
        from torch.utils.tensorboard import SummaryWriter
        tb_dir = '{}/tensorboard'.format(model_dir)
        # Equivalent of `mkdir -p` without spawning a shell (no quoting or
        # injection issues with model_dir); raises instead of failing silently.
        os.makedirs(tb_dir, exist_ok=True)
        tb_writer = SummaryWriter(tb_dir)
    return logger, tb_writer, model_dir