def main()

in distributed_training/train_pytorch_smdataparallel_maskrcnn.py [0:0]


def _download_spot_checkpoint(uri, dest_dir="/opt/ml/checkpoints"):
    """Copy a checkpoint from ``uri`` into ``dest_dir`` and return the local path.

    The copy is done with the AWS CLI (``aws s3 cp``). The command is passed
    as an argument list, so no shell parses the URI. A non-zero exit status
    raises ``subprocess.CalledProcessError`` instead of being ignored; being
    ignored would leave the config pointing at a file that was never written.

    Args:
        uri: Source location handed to ``aws s3 cp`` (e.g. an ``s3://`` URI).
        dest_dir: Local directory the checkpoint is copied into.

    Returns:
        The local file path: ``dest_dir`` joined with the last ``/``-separated
        component of ``uri``.
    """
    import subprocess

    local_path = f"{dest_dir}/{uri.split('/')[-1]}"
    # NOTE(review): every process calling main() runs this copy to the same
    # path; consider restricting it to one process per node if that races.
    subprocess.run(["aws", "s3", "cp", uri, local_path], check=True)
    return local_path


def main():
    """Parse CLI arguments, build the frozen config, train, and optionally test.

    Steps:
      1. Parse command-line options. ``SM_CHANNEL_TRAIN``, when set in the
         environment, overrides ``--data-dir``.
      2. Seed ``random``, NumPy and torch (CPU and CUDA) with
         ``--seed`` plus this process's local rank.
      3. When more than one process is in the world, pin this process to the
         GPU given by ``--local_rank``.
      4. Merge the config file and trailing ``opts`` into ``cfg``. Apply the
         ``--dtype`` override if given. If ``--spot_ckpt`` is given, download
         that checkpoint and use it as ``MODEL.WEIGHT``. Then freeze ``cfg``.
      5. Set up logging, train the model, and run ``test_model`` unless
         ``--skip-test`` is passed or ``cfg.PER_EPOCH_EVAL`` is set.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=dist.get_local_rank())
    parser.add_argument(
        "--seed",
        help="manually set random seed for torch",
        type=int,
        default=99,
    )
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--bucket-cap-mb",
        dest="bucket_cap_mb",
        help="specify bucket size for SMDataParallel",
        default=25,
        type=int,
    )
    parser.add_argument(
        "--data-dir",
        dest="data_dir",
        help="Absolute path of dataset ",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--dtype",
        dest="dtype",
    )
    parser.add_argument(
        "--spot_ckpt",
        default=None,
    )

    args = parser.parse_args()
    # The SageMaker training channel, when present, takes precedence over --data-dir.
    args.data_dir = os.environ.get("SM_CHANNEL_TRAIN", args.data_dir)
    print("dataset dir: ", args.data_dir)

    # Set seed to reduce randomness; offset by local rank so each process on a
    # node draws a different random stream.
    seed = args.seed + dist.get_local_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    num_gpus = dist.get_world_size()
    args.distributed = num_gpus > 1

    if args.distributed:
        # SMDataParallel: Pin each GPU to a single SMDataParallel process.
        torch.cuda.set_device(args.local_rank)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Only override DTYPE when --dtype is given; otherwise keep whatever the
    # config file / opts set (still defining the attribute if it was absent).
    cfg.DTYPE = args.dtype if args.dtype is not None else getattr(cfg, "DTYPE", None)
    # grab checkpoint file to start from (only when one was requested)
    if args.spot_ckpt is not None:
        cfg.MODEL.WEIGHT = _download_spot_checkpoint(args.spot_ckpt)
    cfg.freeze()
    print("CONFIG")
    print(cfg)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, dist.get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args)

    # With per-epoch evaluation enabled, testing already happened during training.
    if not args.skip_test and not cfg.PER_EPOCH_EVAL:
        test_model(cfg, model, args)