in src/sagemaker_defect_detection/classifier.py [0:0]
def main(args: Namespace) -> None:
    """Train and then test a ``DDNClassification`` model.

    Args:
        args: Parsed command-line arguments. Every attribute is forwarded to
            ``DDNClassification`` as a keyword argument. This function also
            reads ``seed``, ``save_path``, ``gpus``, ``epochs`` and
            ``distributed_backend`` directly.
    """
    # Seed *before* constructing the model. Previously the model was built
    # first, so its weight initialization was not covered by the seed and
    # runs were not reproducible even with ``--seed``.
    # ``pl.seed_everything`` calls ``torch.manual_seed``, which seeds all
    # devices, so a separate ``torch.cuda.manual_seed_all`` is not needed.
    if args.seed is not None:
        pl.seed_everything(args.seed)
        # TODO: add deterministic training
        # torch.backends.cudnn.deterministic = True

    model = DDNClassification(**vars(args))

    # Keep only the single best checkpoint, ranked by highest ``val_acc``;
    # the filename embeds the epoch and the logged validation metrics.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.save_path, "{epoch}-{val_loss:.3f}-{val_acc:.3f}"),
        save_top_k=1,
        verbose=True,
        monitor="val_acc",
        mode="max",
    )
    # NOTE(review): early stopping watches ``val_loss`` while checkpointing
    # watches ``val_acc`` -- confirm the mismatch is intentional.
    early_stop_callback = EarlyStopping("val_loss", patience=10)

    trainer = pl.Trainer(
        default_root_dir=args.save_path,
        gpus=args.gpus,
        max_epochs=args.epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        gradient_clip_val=10,
        # Skip the pre-training sanity validation pass.
        num_sanity_val_steps=0,
        # Normalize a falsy value (e.g. empty string) to ``None``.
        distributed_backend=args.distributed_backend or None,
        # precision=16 if args.use_16bit else 32, # TODO: amp apex support
    )
    trainer.fit(model)
    trainer.test()