in activemri/experimental/cvpr19_models/trainer.py
def __call__(self) -> float:
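"""Run the full training loop and return the best validation score.

Sets up logging, builds the reconstructor (and, optionally, the evaluator),
restores checkpoints if available, attaches validation metrics and TensorBoard
handlers, and trains for however many epochs remain.
"""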
self.logger = logging.getLogger()
if self.options.debug:
self.logger.setLevel(logging.DEBUG)
else:
self.logger.setLevel(logging.INFO)
fh = logging.FileHandler(
os.path.join(self.options.checkpoints_dir, "trainer.log")
)
formatter = logging.Formatter(
"%(asctime)s - %(threadName)s - %(levelname)s: %(message)s"
)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
self.logger.info("Creating trainer with the following options:")
for key, value in vars(self.options).items():
if key == "device":
value = value.type
elif key == "gpu_ids":
value = "cuda : " + str(value) if torch.cuda.is_available() else "cpu"
self.logger.info(f" {key:>25}: {'None' if value is None else value:<30}")
# Create Reconstructor Model
self.reconstructor = models.reconstruction.ReconstructorNetwork(
number_of_cascade_blocks=self.options.number_of_cascade_blocks,
n_downsampling=self.options.n_downsampling,
number_of_filters=self.options.number_of_reconstructor_filters,
number_of_layers_residual_bottleneck=self.options.number_of_layers_residual_bottleneck,
mask_embed_dim=self.options.mask_embed_dim,
dropout_probability=self.options.dropout_probability,
img_width=self.options.image_width,
use_deconv=self.options.use_deconv,
)
if self.options.device.type == "cuda":
self.reconstructor = torch.nn.DataParallel(self.reconstructor).to(
self.options.device
)
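# Adam optimizer for the reconstructor; the "G"/"D" keys follow GAN naming
# (reconstructor as generator, evaluator as discriminator).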
self.optimizers = {
"G": optim.Adam(
self.reconstructor.parameters(),
lr=self.options.lr,
betas=(self.options.beta1, 0.999),
)
}
# Create Evaluator Model
if self.options.use_evaluator:
self.evaluator = models.evaluator.EvaluatorNetwork(
number_of_filters=self.options.number_of_evaluator_filters,
number_of_conv_layers=self.options.number_of_evaluator_convolution_layers,
use_sigmoid=False,
width=self.options.image_width,
height=640 if self.options.dataroot == "KNEE_RAW" else None,
mask_embed_dim=self.options.mask_embed_dim,
)
self.evaluator = torch.nn.DataParallel(self.evaluator).to(
self.options.device
)
self.optimizers["D"] = optim.Adam(
self.evaluator.parameters(),
lr=self.options.lr,
betas=(self.options.beta1, 0.999),
)
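# Data loaders and training-state restoration: resume from a previous checkpoint
# if one exists, then optionally load weights from a user-specified checkpoint.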
train_loader, val_loader = self.get_loaders()
self.load_from_checkpoint_if_present()
self.load_weights_from_given_checkpoint()
writer = SummaryWriter(self.options.checkpoints_dir)
# Training engine and handlers
train_engine = Engine(lambda engine, batch: self.update(batch))
val_engine = Engine(lambda engine, batch: self.inference(batch))
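# Validation metrics: MSE and SSIM between reconstructed and ground-truth magnitudes.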
validation_mse = Loss(
loss_fn=F.mse_loss,
output_transform=lambda x: (
x["reconstructed_image_magnitude"],
x["ground_truth_magnitude"],
),
)
validation_mse.attach(val_engine, name="mse")
validation_ssim = Loss(
loss_fn=util.common.compute_ssims,
output_transform=lambda x: (
x["reconstructed_image_magnitude"],
x["ground_truth_magnitude"],
),
)
validation_ssim.attach(val_engine, name="ssim")
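# When the evaluator is enabled, also track the discriminator loss on validation batches.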
if self.options.use_evaluator:
validation_loss_d = Loss(
loss_fn=self.discriminator_loss,
output_transform=lambda x: (
x["reconstructor_eval"],
x["ground_truth_eval"],
{
"reconstructed_image": x["reconstructed_image"],
"target": x["ground_truth"],
"mask": x["mask"],
},
),
)
validation_loss_d.attach(val_engine, name="loss_D")
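# Progress bar for the training engine, plus an end-of-epoch handler that runs
# validation and updates the best checkpoint.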
progress_bar = ProgressBar()
progress_bar.attach(train_engine)
train_engine.add_event_handler(
Events.EPOCH_COMPLETED,
run_validation_and_update_best_checkpoint,
val_engine=val_engine,
progress_bar=progress_bar,
val_loader=val_loader,
trainer=self,
)
# TensorBoard plots: training losses every iteration, validation metrics and images every epoch
@train_engine.on(Events.ITERATION_COMPLETED)
def plot_training_loss(engine):
writer.add_scalar(
"training/generator_loss",
engine.state.output["loss_G"],
self.updates_performed,
)
if "loss_D" in engine.state.output:
writer.add_scalar(
"training/discriminator_loss",
engine.state.output["loss_D"],
self.updates_performed,
)
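# Log validation MSE/SSIM (and discriminator loss, when present) once per epoch.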
@train_engine.on(Events.EPOCH_COMPLETED)
def plot_validation_loss(_):
writer.add_scalar(
"validation/MSE", val_engine.state.metrics["mse"], self.completed_epochs
)
writer.add_scalar(
"validation/SSIM",
val_engine.state.metrics["ssim"],
self.completed_epochs,
)
if "loss_D" in val_engine.state.metrics:
writer.add_scalar(
"validation/loss_D",
val_engine.state.metrics["loss_D"],
self.completed_epochs,
)
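# Write grids of validation images to TensorBoard: ground truth, zero-filled input,
# reconstruction, uncertainty map, absolute error, and the sampling mask.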
@train_engine.on(Events.EPOCH_COMPLETED)
def plot_validation_images(_):
ground_truth = val_engine.state.output["ground_truth_magnitude"]
zero_filled_image = val_engine.state.output["zero_filled_image_magnitude"]
reconstructed_image = val_engine.state.output[
"reconstructed_image_magnitude"
]
uncertainty_map = val_engine.state.output["uncertainty_map"]
difference = torch.abs(ground_truth - reconstructed_image)
# Build an image grid for each tensor and add it to TensorBoard
ground_truth = util.common.create_grid_from_tensor(ground_truth)
writer.add_image(
"validation_images/ground_truth", ground_truth, self.completed_epochs
)
zero_filled_image = util.common.create_grid_from_tensor(zero_filled_image)
writer.add_image(
"validation_images/zero_filled_image",
zero_filled_image,
self.completed_epochs,
)
reconstructed_image = util.common.create_grid_from_tensor(
reconstructed_image
)
writer.add_image(
"validation_images/reconstructed_image",
reconstructed_image,
self.completed_epochs,
)
uncertainty_map = util.common.gray2heatmap(
util.common.create_grid_from_tensor(uncertainty_map.exp()),
cmap="jet",
)
writer.add_image(
"validation_images/uncertainty_map",
uncertainty_map,
self.completed_epochs,
)
difference = util.common.create_grid_from_tensor(difference)
difference = util.common.gray2heatmap(difference, cmap="gray")
writer.add_image(
"validation_images/difference", difference, self.completed_epochs
)
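# The mask holds one value per k-space column; repeating it along the height
# dimension turns it into a full-size image for display.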
mask = util.common.create_grid_from_tensor(
val_engine.state.output["mask"].repeat(
1, 1, val_engine.state.output["mask"].shape[3], 1
)
)
writer.add_image(
"validation_images/mask_image", mask, self.completed_epochs
)
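# Also save a regular (non-best) checkpoint at the end of each epoch.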
train_engine.add_event_handler(
Events.EPOCH_COMPLETED,
save_regular_checkpoint,
trainer=self,
progress_bar=progress_bar,
)
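# Train for the epochs remaining after any resume, then close the writer and
# report the best validation score.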
train_engine.run(train_loader, self.options.max_epochs - self.completed_epochs)
writer.close()
return self.best_validation_score