# train.py — single-process entry point (main_single)
def main_single(gpu, FLAGS):
    """Per-process training entry point for one GPU worker.

    Sets up (optionally distributed) process state, builds the
    train/val/test datasets and dataloaders, constructs the model and
    optimizer, restores a checkpoint when ``FLAGS.resume_iter`` is nonzero,
    and finally dispatches to ``train`` or ``test``.

    Args:
        gpu: Local GPU index for this process on the current node.
        FLAGS: Parsed command-line flags object; mutated/replaced in place
            when resuming from a checkpoint.
    """
    if FLAGS.slurm:
        init_distributed_mode(FLAGS)

    os.environ["MASTER_ADDR"] = str(FLAGS.master_addr)
    os.environ["MASTER_PORT"] = str(FLAGS.port)

    # Global rank across all nodes, and total number of workers.
    rank_idx = FLAGS.node_rank * FLAGS.gpus + gpu
    world_size = FLAGS.nodes * FLAGS.gpus

    # Only the rank-0 worker logs, to avoid duplicated output.
    if rank_idx == 0:
        print("Values of args: ", FLAGS)

    if world_size > 1:
        if FLAGS.slurm:
            # SLURM launch: rendezvous via the MASTER_ADDR/MASTER_PORT env vars.
            dist.init_process_group(
                backend="nccl", init_method="env://", world_size=world_size, rank=rank_idx
            )
        else:
            # Single-machine multi-GPU launch: rendezvous over localhost TCP.
            dist.init_process_group(
                backend="nccl",
                init_method="tcp://localhost:1492",
                world_size=world_size,
                rank=rank_idx,
            )

    train_dataset = MMCIFTransformer(
        FLAGS,
        split="train",
        rank_idx=rank_idx,
        world_size=world_size,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )
    valid_dataset = MMCIFTransformer(
        FLAGS,
        split="val",
        rank_idx=rank_idx,
        world_size=world_size,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )
    # NOTE(review): the test split uses rank_idx=0 / world_size=1, so every
    # worker sees the full, unsharded test set — presumably intentional;
    # confirm against the dataset's sharding logic.
    test_dataset = MMCIFTransformer(
        FLAGS,
        split="test",
        rank_idx=0,
        world_size=1,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )

    # Effective per-step batch is batch_size; each sample is drawn
    # `multisample` times, so the loader batch is divided accordingly.
    train_dataloader = DataLoader(
        train_dataset,
        num_workers=FLAGS.data_workers,
        collate_fn=collate_fn_transformer,
        batch_size=FLAGS.batch_size // FLAGS.multisample,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        num_workers=0,
        collate_fn=collate_fn_transformer_test,
        batch_size=FLAGS.batch_size // FLAGS.multisample,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        num_workers=0,
        collate_fn=collate_fn_transformer_test,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )

    train_structures = train_dataset.files

    FLAGS_OLD = FLAGS
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    # Load the checkpoint at most once; it is reused below for the
    # model/optimizer state dicts (the original code loaded the same file twice).
    checkpoint = None
    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        try:
            FLAGS = checkpoint["FLAGS"]
            # Restore arguments to saved checkpoint values except for a
            # select few that must reflect the current launch configuration.
            FLAGS.resume_iter = FLAGS_OLD.resume_iter
            FLAGS.nodes = FLAGS_OLD.nodes
            FLAGS.gpus = FLAGS_OLD.gpus
            FLAGS.node_rank = FLAGS_OLD.node_rank
            FLAGS.master_addr = FLAGS_OLD.master_addr
            FLAGS.neg_sample = FLAGS_OLD.neg_sample
            FLAGS.train = FLAGS_OLD.train
            FLAGS.multisample = FLAGS_OLD.multisample
            FLAGS.steps = FLAGS_OLD.steps
            FLAGS.step_lr = FLAGS_OLD.step_lr
            FLAGS.batch_size = FLAGS_OLD.batch_size

            # Copy every restored attribute back onto the live flags object.
            # NOTE(review): item assignment (FLAGS_OLD[key] = ...) only works
            # if the flags object supports __setitem__ (absl-style); an
            # argparse Namespace would raise here and be swallowed by the
            # except below — confirm which flags library is in use.
            for key in dir(FLAGS):
                if "__" not in key:
                    FLAGS_OLD[key] = getattr(FLAGS, key)

            FLAGS = FLAGS_OLD
        except Exception as e:
            print(e)
            print("Didn't find keys in checkpoint")

    # Build the model; fail fast on an unrecognized architecture name
    # (previously this fell through and raised a confusing NameError later).
    if FLAGS.model == "transformer":
        model = RotomerTransformerModel(FLAGS).train()
    elif FLAGS.model == "fc":
        model = RotomerFCModel(FLAGS).train()
    elif FLAGS.model == "s2s":
        model = RotomerSet2SetModel(FLAGS).train()
    elif FLAGS.model == "graph":
        model = RotomerGraphModel(FLAGS).train()
    else:
        raise ValueError("Unknown model type: {}".format(FLAGS.model))

    if FLAGS.cuda:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)

    optimizer = optim.Adam(model.parameters(), lr=FLAGS.start_lr, betas=(0.99, 0.999))

    if FLAGS.gpus > 1:
        # Broadcast initial parameters so all workers start identical.
        sync_model(model)

    logger = TensorBoardOutputFormat(logdir)
    it = FLAGS.resume_iter

    if not osp.exists(logdir):
        os.makedirs(logdir)

    if FLAGS.resume_iter != 0 and checkpoint is not None:
        try:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            model.load_state_dict(checkpoint["model_state_dict"])
        except Exception as e:
            # Checkpoint was saved from a DistributedDataParallel-wrapped
            # model; strip the "module." prefix and retry.
            print("Transfer between distributed to non-distributed")
            model_state_dict = {
                k.replace("module.", ""): v for k, v in checkpoint["model_state_dict"].items()
            }
            model.load_state_dict(model_state_dict)

    pytorch_total_params = sum([p.numel() for p in model.parameters() if p.requires_grad])

    if rank_idx == 0:
        print("New Values of args: ", FLAGS)
        print("Number of parameters for models", pytorch_total_params)

    if FLAGS.train:
        train(
            train_dataloader,
            valid_dataloader,
            logger,
            model,
            optimizer,
            FLAGS,
            logdir,
            rank_idx,
            train_structures,
            checkpoint=checkpoint,
        )
    else:
        test(test_dataloader, model, FLAGS, logdir)