in training/trainer.py [0:0]
def _setup_components(self):
    """Build the trainer's runtime components from its config sections.

    Validates the val-dataset config, resets the epoch/step counters, and
    instantiates the logger, model, losses, meters, AMP grad scaler and the
    optional gradient clipper / gradient logger, storing each on ``self``.
    """
    # Get the keys for all the val datasets, if any
    val_phase = Phase.VAL
    val_keys = None
    if self.data_conf.get(val_phase) is not None:
        val_keys = collect_dict_keys(self.data_conf[val_phase])
    # Additional checks on the sanity of the config for val datasets.
    # Called even when no val data is configured (val_keys stays None).
    self._check_val_key_match(val_keys, phase=val_phase)

    logging.info("Setting up components: Model, loss, optim, meters etc.")
    self.epoch = 0
    self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}

    self.logger = Logger(self.logging_conf)

    self.model = instantiate(self.model_conf, _convert_="all")
    print_model_summary(self.model)

    # Losses are optional. When configured, the instantiated config is a
    # mapping of name -> loss; nn.ModuleDict accepts that mapping directly
    # (its values must be nn.Module instances), so no intermediate copy is
    # needed.
    self.loss = None
    if self.loss_conf:
        self.loss = nn.ModuleDict(instantiate(self.loss_conf, _convert_="all"))

    self.meters = {}
    self.best_meter_values = {}
    if self.meters_conf:
        self.meters = instantiate(self.meters_conf, _convert_="all")

    # AMP grad scaler: enabled only when an optim config exists and its
    # ``amp.enabled`` flag is set; otherwise constructed disabled.
    self.scaler = torch.amp.GradScaler(
        self.device,
        enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
    )

    # Optional gradient helpers; both are None when there is no optim config.
    self.gradient_clipper = (
        instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
    )
    self.gradient_logger = (
        instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
    )

    logging.info("Finished setting up components: Model, loss, optim, meters etc.")