def main_single()

in train.py
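
Entry point for a single worker process: it sets the distributed rendezvous variables, builds the train/val/test datasets and dataloaders, optionally restores flags and weights from a saved checkpoint, constructs the model and optimizer, and then dispatches to train() or test().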


def main_single(gpu, FLAGS):
    if FLAGS.slurm:
        init_distributed_mode(FLAGS)

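    # Rendezvous address/port for torch.distributed, plus this process's
    # global rank and the total number of processes across all nodes.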
    os.environ["MASTER_ADDR"] = str(FLAGS.master_addr)
    os.environ["MASTER_PORT"] = str(FLAGS.port)

    rank_idx = FLAGS.node_rank * FLAGS.gpus + gpu
    world_size = FLAGS.nodes * FLAGS.gpus

    if rank_idx == 0:
        print("Values of args: ", FLAGS)

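    # For multi-process runs, initialize the NCCL process group: env://
    # rendezvous under SLURM, a fixed localhost TCP port otherwise.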
    if world_size > 1:
        if FLAGS.slurm:
            dist.init_process_group(
                backend="nccl", init_method="env://", world_size=world_size, rank=rank_idx
            )
        else:
            dist.init_process_group(
                backend="nccl",
                init_method="tcp://localhost:1492",
                world_size=world_size,
                rank=rank_idx,
            )

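    # Dataset splits: train/val are sharded by rank across the world size,
    # while the test split is always built as a single unsharded copy.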
    train_dataset = MMCIFTransformer(
        FLAGS,
        split="train",
        rank_idx=rank_idx,
        world_size=world_size,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )
    valid_dataset = MMCIFTransformer(
        FLAGS,
        split="val",
        rank_idx=rank_idx,
        world_size=world_size,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )
    test_dataset = MMCIFTransformer(
        FLAGS,
        split="test",
        rank_idx=0,
        world_size=1,
        uniform=FLAGS.uniform,
        weighted_gauss=FLAGS.weighted_gauss,
        gmm=FLAGS.gmm,
        chi_mean=FLAGS.chi_mean,
        mmcif_path=MMCIF_PATH,
    )
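    # Per-process dataloaders; the train/val batch size is divided by the
    # multisample factor, while the test loader uses the full batch size.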
    train_dataloader = DataLoader(
        train_dataset,
        num_workers=FLAGS.data_workers,
        collate_fn=collate_fn_transformer,
        batch_size=FLAGS.batch_size // FLAGS.multisample,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        num_workers=0,
        collate_fn=collate_fn_transformer_test,
        batch_size=FLAGS.batch_size // FLAGS.multisample,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        num_workers=0,
        collate_fn=collate_fn_transformer_test,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )

    train_structures = train_dataset.files

    FLAGS_OLD = FLAGS

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

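    # When resuming, restore the flags saved in the checkpoint, but keep the
    # run-specific values (distributed layout, sampling, batch size, LR steps)
    # from the current invocation.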
    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        try:
            FLAGS = checkpoint["FLAGS"]

            # Restore arguments to saved checkpoint values except for a select few
            FLAGS.resume_iter = FLAGS_OLD.resume_iter
            FLAGS.nodes = FLAGS_OLD.nodes
            FLAGS.gpus = FLAGS_OLD.gpus
            FLAGS.node_rank = FLAGS_OLD.node_rank
            FLAGS.master_addr = FLAGS_OLD.master_addr
            FLAGS.neg_sample = FLAGS_OLD.neg_sample
            FLAGS.train = FLAGS_OLD.train
            FLAGS.multisample = FLAGS_OLD.multisample
            FLAGS.steps = FLAGS_OLD.steps
            FLAGS.step_lr = FLAGS_OLD.step_lr
            FLAGS.batch_size = FLAGS_OLD.batch_size

            # Copy every attribute restored from the checkpoint back onto the
            # current flags object.
            for key in dir(FLAGS):
                if "__" not in key:
                    setattr(FLAGS_OLD, key, getattr(FLAGS, key))

            FLAGS = FLAGS_OLD
        except Exception as e:
            print(e)
            print("Didn't find keys in checkpoint")

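    # Instantiate the architecture selected by FLAGS.model.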
    if FLAGS.model == "transformer":
        model = RotomerTransformerModel(FLAGS).train()
    elif FLAGS.model == "fc":
        model = RotomerFCModel(FLAGS).train()
    elif FLAGS.model == "s2s":
        model = RotomerSet2SetModel(FLAGS).train()
    elif FLAGS.model == "graph":
        model = RotomerGraphModel(FLAGS).train()
    else:
        raise ValueError("Unknown model: {}".format(FLAGS.model))

    if FLAGS.cuda:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)

    optimizer = optim.Adam(model.parameters(), lr=FLAGS.start_lr, betas=(0.99, 0.999))

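    # When training on multiple GPUs, start every worker from the same
    # initial weights.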
    if FLAGS.gpus > 1:
        sync_model(model)

    logger = TensorBoardOutputFormat(logdir)

    it = FLAGS.resume_iter

    if not osp.exists(logdir):
        os.makedirs(logdir)

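    # Reload model/optimizer state when resuming; if the checkpoint was saved
    # from a module-wrapped (e.g. DistributedDataParallel) model, strip the
    # "module." prefix before loading.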
    checkpoint = None
    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        try:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            model.load_state_dict(checkpoint["model_state_dict"])
        except Exception as e:
            print("Transferring checkpoint from a distributed to a non-distributed model")

            model_state_dict = {
                k.replace("module.", ""): v for k, v in checkpoint["model_state_dict"].items()
            }
            model.load_state_dict(model_state_dict)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    if rank_idx == 0:
        print("New values of args: ", FLAGS)
        print("Number of parameters in model: ", pytorch_total_params)

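    # Dispatch to the training loop, or run a one-off evaluation on the test split.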
    if FLAGS.train:
        train(
            train_dataloader,
            valid_dataloader,
            logger,
            model,
            optimizer,
            FLAGS,
            logdir,
            rank_idx,
            train_structures,
            checkpoint=checkpoint,
        )
    else:
        test(test_dataloader, model, FLAGS, logdir)
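
For context, here is a minimal sketch of how an entry point with this signature is typically launched. It is an illustration only, not train.py's actual __main__ block: the parser name and the spawn-vs-direct split are assumptions, but torch.multiprocessing.spawn does pass the worker index as the first argument, matching main_single(gpu, FLAGS).

import torch.multiprocessing as mp

def main():
    FLAGS = parser.parse_args()  # assumes train.py's argparse parser
    if FLAGS.gpus > 1:
        # One process per GPU; each receives its index as the gpu argument.
        mp.spawn(main_single, args=(FLAGS,), nprocs=FLAGS.gpus)
    else:
        main_single(0, FLAGS)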