# Excerpt: main() from src/sagemaker_defect_detection/detector.py

def main(args: Namespace) -> None:
    """Build and train a DDNDetection model with PyTorch Lightning.

    Args:
        args: Parsed CLI namespace. Fields read here: ``seed``, ``gpus``,
            ``epochs``, ``save_path``, ``distributed_backend``,
            ``resume_from_checkpoint``; everything else is forwarded to
            ``DDNDetection``.
    """
    ddn = DDNDetection(**vars(args))

    if args.seed is not None:
        pl.seed_everything(args.seed)  # doesn't do multi-gpu
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(args.seed)

    # TODO: add deterministic training
    # torch.backends.cudnn.deterministic = True

    if ddn.train_rpn or ddn.train_roi:
        # RPN and ROI pre-training stages share the same checkpointing setup:
        # keep the single best checkpoint by minimum training loss, no early
        # stopping. (Previously duplicated verbatim across two elif branches.)
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(args.save_path, "{epoch}-{loss:.3f}"),
            save_top_k=1,
            verbose=True,
            monitor="loss",
            mode="min",
        )
        early_stop_callback = None
    else:
        # Final stage: track the validation metric "main_score" (higher is
        # better) and stop early after 50 epochs without improvement.
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(args.save_path, "{epoch}-{loss:.3f}-{main_score:.3f}"),
            save_top_k=1,
            verbose=True,
            monitor="main_score",
            mode="max",
        )
        early_stop_callback = EarlyStopping("main_score", patience=50, mode="max")

    trainer = pl.Trainer(
        default_root_dir=args.save_path,
        num_sanity_val_steps=1,
        limit_val_batches=1.0,
        gpus=args.gpus,
        max_epochs=args.epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        # Empty string means "no backend"; Trainer expects None in that case.
        distributed_backend=args.distributed_backend or None,
        # precision=16 if args.use_16bit else 32, # TODO: apex
        weights_summary="top",
        # Empty string means "start fresh"; Trainer expects None in that case.
        resume_from_checkpoint=None if args.resume_from_checkpoint == "" else args.resume_from_checkpoint,
    )

    trainer.fit(ddn)