def main()

in src/sagemaker_defect_detection/classifier.py [0:0]


def main(args: Namespace) -> None:
    """Train and then test a ``DDNClassification`` model.

    Args:
        args: Parsed options. Every attribute is forwarded to
            ``DDNClassification`` as a keyword argument; this function itself
            reads ``seed``, ``save_path``, ``gpus``, ``epochs`` and
            ``distributed_backend``.
    """
    # Seed *before* the model is built so that any random initialisation done
    # while constructing it is governed by the requested seed as well.
    if args.seed is not None:
        pl.seed_everything(args.seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(args.seed)

    # TODO: add deterministic training
    # torch.backends.cudnn.deterministic = True

    model = DDNClassification(**vars(args))

    # Checkpoints go under save_path; the file name template embeds the epoch
    # and validation metrics. Configured with save_top_k=1 on "val_acc" (max).
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.save_path, "{epoch}-{val_loss:.3f}-{val_acc:.3f}"),
        save_top_k=1,
        verbose=True,
        monitor="val_acc",
        mode="max",
    )
    # Early stopping watches "val_loss" with a patience of 10.
    early_stop_callback = EarlyStopping("val_loss", patience=10)
    trainer = pl.Trainer(
        default_root_dir=args.save_path,
        gpus=args.gpus,
        max_epochs=args.epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        gradient_clip_val=10,
        num_sanity_val_steps=0,
        # A falsy backend value (e.g. an empty string) is normalised to None.
        distributed_backend=args.distributed_backend or None,
        # precision=16 if args.use_16bit else 32, # TODO: amp apex support
    )

    trainer.fit(model)
    trainer.test()