in vis_sandbox.py [0:0]
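# Imports assumed by the function below (os/os.path/torch/nn are certain from
# usage). The Rotomer* model classes, MMCIFTransformer, MMCIF_PATH, and the
# task helpers (pair_model, pack_rotamer, rotamer_trials, new_model, make_tsne)
# are defined or imported elsewhere in vis_sandbox.py and are not shown here.
import os
import os.path as osp

import torch
from torch import nn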
def main_single(FLAGS):
    FLAGS_OLD = FLAGS

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    if FLAGS.resume_iter != 0:
        # Restore the flags saved with the checkpoint, but keep the resume
        # iteration and negative-sample settings supplied on the command line.
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)

        try:
            FLAGS = checkpoint["FLAGS"]
            FLAGS.resume_iter = FLAGS_OLD.resume_iter
            FLAGS.neg_sample = FLAGS_OLD.neg_sample

            for key in FLAGS.keys():
                if "__" not in key:
                    FLAGS_OLD[key] = getattr(FLAGS, key)

            FLAGS = FLAGS_OLD
        except Exception as e:
            print(e)
            print("Didn't find keys in checkpoint")
    models = []

    if FLAGS.ensemble > 1:
        for i in range(FLAGS.ensemble):
            if FLAGS.model == "transformer":
                model = RotomerTransformerModel(FLAGS).eval()
            elif FLAGS.model == "fc":
                model = RotomerFCModel(FLAGS).eval()
            elif FLAGS.model == "s2s":
                model = RotomerSet2SetModel(FLAGS).eval()
            elif FLAGS.model == "graph":
                model = RotomerGraphModel(FLAGS).eval()

            models.append(model)
    else:
        if FLAGS.model == "transformer":
            model = RotomerTransformerModel(FLAGS).eval()
        elif FLAGS.model == "fc":
            model = RotomerFCModel(FLAGS).eval()
        elif FLAGS.model == "s2s":
            model = RotomerSet2SetModel(FLAGS).eval()
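    # Single-process, single-GPU setup: the GPU index and (distributed) world
    # size are hard-coded for this visualization script.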
    gpu = 0
    world_size = 0
    it = FLAGS.resume_iter

    if not osp.exists(logdir):
        os.makedirs(logdir)

    checkpoint = None
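    # Load checkpoint weights. For an ensemble, member i loads the checkpoint
    # saved 1000 * i iterations before resume_iter, so the ensemble spans
    # several recent snapshots of training.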
    if FLAGS.ensemble > 1:
        for i, model in enumerate(models):
            if FLAGS.resume_iter != 0:
                model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter - i * 1000))
                checkpoint = torch.load(model_path)

                try:
                    model.load_state_dict(checkpoint["model_state_dict"])
                except Exception as e:
                    # Checkpoint was saved from a DataParallel / DistributedDataParallel
                    # model; strip the "module." prefix before loading.
                    print("Transferring weights from a distributed to a non-distributed model")
                    model_state_dict = {
                        k.replace("module.", ""): v
                        for k, v in checkpoint["model_state_dict"].items()
                    }
                    model.load_state_dict(model_state_dict)

            models[i] = nn.DataParallel(model)

        model = models
    else:
        if FLAGS.resume_iter != 0:
            model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
            checkpoint = torch.load(model_path)

            try:
                model.load_state_dict(checkpoint["model_state_dict"])
            except Exception as e:
                # Same fallback as above: strip the "module." prefix left by
                # DataParallel / DistributedDataParallel wrappers.
                print("Transferring weights from a distributed to a non-distributed model")
                model_state_dict = {
                    k.replace("module.", ""): v
                    for k, v in checkpoint["model_state_dict"].items()
                }
                model.load_state_dict(model_state_dict)

        model = nn.DataParallel(model)
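    # Move the model(s) to the selected GPU when CUDA is enabled.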
    if FLAGS.cuda:
        if FLAGS.ensemble > 1:
            for i, model in enumerate(models):
                models[i] = model.cuda(gpu)

            model = models
        else:
            torch.cuda.set_device(gpu)
            model = model.cuda(gpu)

    FLAGS.multisample = 1

    print("New Values of args: ", FLAGS)
    with torch.no_grad():
        if FLAGS.task == "pair_atom":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            (
                node_embeds,
                node_embeds_negatives,
                select_atom_idxs,
                select_atom_masks,
                select_chis_valids,
                select_ancestors,
            ) = train_dataset[256]
            pair_model(model, FLAGS, node_embeds)
        elif FLAGS.task == "pack_rotamer":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[0]
            pack_rotamer(model, FLAGS, node_embed)
        elif FLAGS.task == "rotamer_trial":
            test_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="test")
            rotamer_trials(model, FLAGS, test_dataset)
        elif FLAGS.task == "new_model":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[7]
            new_model(model, FLAGS, node_embed)
        elif FLAGS.task == "tsne":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[3]
            make_tsne(model, FLAGS, node_embed)
        else:
            assert False, "Unknown task {}".format(FLAGS.task)
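# A minimal sketch of how this function might be invoked; the real entry point
# and flag parsing live elsewhere in vis_sandbox.py and are not shown in this
# excerpt. "parse_flags" below is a hypothetical stand-in for that parsing step.
# if __name__ == "__main__":
#     FLAGS = parse_flags()
#     main_single(FLAGS)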