in train/learner.py [0:0]
def _collect_git_metadata():
    """Return a dict describing the enclosing git checkout, or None.

    This is best-effort. It returns None when GitPython is not importable,
    when the current directory is not inside a git repository, or when the
    repository has no commits yet.
    """
    try:
        import git
    except ImportError:
        return None

    try:
        repo = git.Repo(search_parent_directories=True)
        git_data = {
            "commit": repo.commit().hexsha,
            "is_dirty": repo.is_dirty(),
            "path": repo.git_dir,
        }
        # `active_branch` is only meaningful when HEAD points at a branch.
        if not repo.head.is_detached:
            git_data["branch"] = repo.active_branch.name
    except (git.InvalidGitRepositoryError, ValueError):
        # ValueError: `repo.commit()` on a repository with no commits yet.
        # Metadata is optional, so don't let it abort the run.
        return None
    return git_data


def _write_metadata(logdir, flags):
    """Write ``meta.json`` into *logdir* with flags, environment and git info.

    Args:
        logdir: Directory to write into. It is created if missing.
        flags: Run configuration; converted with ``OmegaConf.to_container``
            (presumably an OmegaConf config object; confirm against callers).
    """
    metadata = {
        "flags": OmegaConf.to_container(flags),
        # NOTE(review): this records the full process environment, which may
        # include credentials; consider filtering before sharing logdirs.
        "env": os.environ.copy(),
        "date_start": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    git_data = _collect_git_metadata()
    if git_data is None:
        logging.warning("Couldn't determine git data.")
    else:
        metadata["git"] = git_data

    os.makedirs(logdir, exist_ok=True)
    with open(os.path.join(logdir, "meta.json"), "w") as f:
        json.dump(metadata, f, indent=2, sort_keys=True)


def _main(flags, rank=0, barrier=None, gossip_buffer=None, stop_event=None):
    """Seed torch, record run metadata, and dispatch on ``flags.mode``.

    Args:
        flags: Run configuration. Reads ``seed``, ``logdir`` and ``mode``.
        rank, barrier, gossip_buffer, stop_event: Forwarded unchanged to
            ``learner_loop`` when ``flags.mode == "train"``.

    Raises:
        RuntimeError: If ``flags.mode`` is neither ``"train"`` nor ``"trace"``.

    On success, when ``flags.logdir`` is set, an empty ``OK`` file is written
    there to mark completion.
    """
    torch.random.manual_seed(flags.seed)
    if flags.logdir:
        _write_metadata(flags.logdir, flags)

    if flags.mode == "train":
        learner_loop(flags, rank, barrier, gossip_buffer, stop_event)
    elif flags.mode == "trace":
        trace(flags)
    else:
        # Test mode unsupported in learner. We rely on "local" testing
        # by tracing the model and running it via C++ in mvfst without RPC.
        raise RuntimeError("Unsupported mode {}".format(flags.mode))

    if flags.logdir:
        # Write an empty "OK" flag to indicate success.
        with (Path(flags.logdir) / "OK").open("w"):
            pass