def main_single()

in vis_sandbox.py [0:0]


def _restore_checkpoint_flags(FLAGS, logdir):
    """Merge hyperparameters saved in checkpoint ``model_<resume_iter>`` into FLAGS.

    Values stored under ``checkpoint["FLAGS"]`` win, except ``resume_iter`` and
    ``neg_sample``, which keep the values passed in. The merge is best-effort:
    if it raises, the error is printed and whatever FLAGS object had been
    reached at that point is returned.

    Args:
        FLAGS: the command-line flags object for this run.
        logdir: directory containing ``model_<iter>`` checkpoint files.

    Returns:
        The flags object to use for the rest of the run.
    """
    FLAGS_OLD = FLAGS
    model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
    # torch.load unpickles arbitrary objects: only load trusted checkpoints.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location="cpu")
    try:
        FLAGS = checkpoint["FLAGS"]

        FLAGS.resume_iter = FLAGS_OLD.resume_iter
        FLAGS.neg_sample = FLAGS_OLD.neg_sample

        for key in FLAGS.keys():
            if "__" not in key:
                FLAGS_OLD[key] = getattr(FLAGS, key)

        FLAGS = FLAGS_OLD
    except Exception as e:
        # NOTE(review): if the saved FLAGS has no .keys() (e.g. an
        # argparse.Namespace — confirm the saved type), this path returns the
        # checkpoint's FLAGS with only resume_iter/neg_sample overridden.
        print(e)
        print("Didn't find keys in checkpoint'")
    return FLAGS


def _make_rotomer_model(FLAGS):
    """Instantiate the model class selected by ``FLAGS.model`` in eval mode.

    Raises:
        ValueError: if ``FLAGS.model`` is not a supported model name.
    """
    # An if/elif chain (rather than a dict) keeps class-name lookups lazy, so
    # only the selected model class needs to be importable.
    name = FLAGS.model
    if name == "transformer":
        model = RotomerTransformerModel(FLAGS)
    elif name == "fc":
        model = RotomerFCModel(FLAGS)
    elif name == "s2s":
        model = RotomerSet2SetModel(FLAGS)
    elif name == "graph":
        model = RotomerGraphModel(FLAGS)
    else:
        raise ValueError("Unknown model type: {!r}".format(name))
    return model.eval()


def _load_checkpoint_weights(model, model_path):
    """Load ``checkpoint["model_state_dict"]`` from ``model_path`` into ``model``.

    If the direct load fails with a key/shape mismatch (``RuntimeError``), the
    load is retried after stripping a leading ``"module."`` from every key —
    the prefix added by ``nn.DataParallel``-wrapped models when saved.
    """
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["model_state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        print("Transfer between distributed to non-distributed")
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)


def main_single(FLAGS):
    """Build the model(s), optionally restore a checkpoint, and run ``FLAGS.task``.

    Checkpoints are read from ``<FLAGS.logdir>/<FLAGS.exp>/model_<iter>``. When
    ``FLAGS.resume_iter`` is non-zero, saved hyperparameters are merged into
    FLAGS first and the model weights are restored. With ``FLAGS.ensemble > 1``,
    ensemble member ``i`` is restored from iteration ``resume_iter - i * 1000``.
    ``FLAGS.multisample`` is forced to 1 before the task runs.

    Supported tasks: ``pair_atom``, ``pack_rotamer``, ``rotamer_trial``,
    ``new_model``, ``tsne``.

    Raises:
        ValueError: for an unknown ``FLAGS.model`` or ``FLAGS.task``.
    """
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    if FLAGS.resume_iter != 0:
        FLAGS = _restore_checkpoint_flags(FLAGS, logdir)

    gpu = 0
    use_ensemble = FLAGS.ensemble > 1

    if use_ensemble:
        models = [_make_rotomer_model(FLAGS) for _ in range(FLAGS.ensemble)]
    else:
        model = _make_rotomer_model(FLAGS)

    os.makedirs(logdir, exist_ok=True)

    if use_ensemble:
        for i, member in enumerate(models):
            if FLAGS.resume_iter != 0:
                # Member i is restored from the checkpoint i * 1000 iterations
                # before resume_iter.
                member_iter = FLAGS.resume_iter - i * 1000
                member_path = osp.join(logdir, "model_{}".format(member_iter))
                _load_checkpoint_weights(member, member_path)
            models[i] = nn.DataParallel(member)
        model = models
    elif FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        _load_checkpoint_weights(model, model_path)
        # NOTE(review): a single model is wrapped in DataParallel only when
        # resuming, whereas ensemble members are always wrapped. Kept as-is;
        # confirm whether downstream tasks rely on this.
        model = nn.DataParallel(model)

    if FLAGS.cuda:
        if use_ensemble:
            models = [member.cuda(gpu) for member in models]
            model = models
        else:
            torch.cuda.set_device(gpu)
            model = model.cuda(gpu)

    FLAGS.multisample = 1
    print("New Values of args: ", FLAGS)

    with torch.no_grad():
        # Exactly one task runs; an elif chain ensures the final else only
        # fires for unknown tasks (previously it fired for every task but tsne).
        if FLAGS.task == "pair_atom":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embeds, _, _, _, _, _ = train_dataset[256]
            pair_model(model, FLAGS, node_embeds)
        elif FLAGS.task == "pack_rotamer":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[0]
            pack_rotamer(model, FLAGS, node_embed)
        elif FLAGS.task == "rotamer_trial":
            test_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="test")
            rotamer_trials(model, FLAGS, test_dataset)
        elif FLAGS.task == "new_model":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[7]
            new_model(model, FLAGS, node_embed)
        elif FLAGS.task == "tsne":
            train_dataset = MMCIFTransformer(FLAGS, mmcif_path=MMCIF_PATH, split="train")
            node_embed, _, _, _, _, _ = train_dataset[3]
            make_tsne(model, FLAGS, node_embed)
        else:
            raise ValueError("Unknown task: {!r}".format(FLAGS.task))