# distributed_training/train_pytorch_single_maskrcnn.py
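
# NOTE: import sketch for this excerpt. The standard-library and torch imports
# are exactly what main() uses; the maskrcnn_benchmark imports follow that
# library's layout, and train()/test_model() are assumed to be defined earlier
# in the same script rather than imported.
import argparse
import os
import random

import numpy as np
import torch

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
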
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # parser.add_argument("--local_rank", type=int, default=dist.get_local_rank())
    parser.add_argument(
        "--seed",
        help="manually set random seed for torch",
        type=int,
        default=99,
    )
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="do not test the final model",
        action="store_true",
    )
    # Positional catch-all: leftover command-line tokens are merged into the
    # config, e.g. "SOLVER.MAX_ITER 1000".
    parser.add_argument(
        "opts",
        help="modify config options using the command line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--bucket-cap-mb",
        dest="bucket_cap_mb",
        help="specify bucket size for SMDataParallel",
        default=25,
        type=int,
    )
    parser.add_argument(
        "--data-dir",
        dest="data_dir",
        help="absolute path of the dataset",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--dtype",
        dest="dtype",
        help="training precision written to cfg.DTYPE, e.g. float16 or float32",
        default=None,
    )
    parser.add_argument(
        "--spot_ckpt",
        help="S3 URI of a checkpoint to resume from (spot training restarts)",
        default=None,
    )
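    # Example invocation (illustrative config path and values, not taken from
    # this repo):
    #
    #   python train_pytorch_single_maskrcnn.py \
    #       --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml \
    #       --dtype float16 \
    #       SOLVER.MAX_ITER 1000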
    args = parser.parse_args()

    # Prefer the SageMaker training channel when it is set, otherwise keep
    # whatever --data-dir was given.
    args.data_dir = os.environ.get("SM_CHANNEL_TRAIN", args.data_dir)
    print("dataset dir: ", args.data_dir)

    # Set seeds to reduce randomness
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
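    # NOTE: these seeds reduce but do not eliminate nondeterminism on GPU;
    # fully reproducible runs would additionally need
    # torch.backends.cudnn.deterministic = True, at some throughput cost.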
    # Single-process training: the SMDataParallel setup is disabled in this
    # single-GPU variant of the script.
    args.distributed = False
    # if args.distributed:
    #     # SMDataParallel: Pin each GPU to a single SMDataParallel process.
    #     torch.cuda.set_device(args.local_rank)
    #     torch.distributed.init_process_group(backend="nccl", init_method="env://")
    #     synchronize()
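    # For reference, a distributed variant would initialize SMDataParallel
    # roughly as follows (sketch; smdistributed is only available inside
    # SageMaker's training containers):
    #
    #   import smdistributed.dataparallel.torch.distributed as dist
    #   dist.init_process_group()
    #   torch.cuda.set_device(dist.get_local_rank())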
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    if args.dtype is not None:
        cfg.DTYPE = args.dtype

    # Grab the checkpoint file to start from; skip when no spot checkpoint
    # URI was passed.
    if args.spot_ckpt is not None:
        ckpt_name = args.spot_ckpt.split("/")[-1]
        os.system(f"aws s3 cp {args.spot_ckpt} /opt/ml/checkpoints/{ckpt_name}")
        cfg.MODEL.WEIGHT = f"/opt/ml/checkpoints/{ckpt_name}"
    cfg.freeze()

    print("CONFIG")
    print(cfg)
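    # Shell-free alternative to the CLI copy above (sketch, assuming boto3 is
    # installed and spot_ckpt is a well-formed s3://bucket/key URI):
    #
    #   import boto3
    #   bucket, _, key = args.spot_ckpt[len("s3://"):].partition("/")
    #   boto3.client("s3").download_file(bucket, key, f"/opt/ml/checkpoints/{ckpt_name}")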
    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, 0)
    logger.info(args)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))
    model = train(cfg, args)

    if not args.skip_test:
        # When PER_EPOCH_EVAL is enabled, evaluation presumably already ran
        # inside train(); only run a final test pass otherwise.
        if not cfg.PER_EPOCH_EVAL:
            test_model(cfg, model, args)
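
# Standard module entry point (assumed; not shown in this excerpt):
if __name__ == "__main__":
    main()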