in classy_vision/tasks/classification_task.py [0:0]
def prepare(self):
    """Prepares task for training, populates all derived attributes.

    The steps below run in a fixed order:

    1. Build ``self.phases``; force ``self.train`` to False when
       ``self.test_only`` is set.
    2. If ``self.batch_norm_sync_mode`` requests it, convert the model's
       BatchNorm layers to synchronized BatchNorm (PyTorch or Apex flavour).
    3. Move model and loss to GPU (via ``copy_model_to_gpu``) or to CPU,
       according to ``self.use_gpu``.
    4. If an optimizer is set, call ``self.prepare_optimizer`` with it.
    5. If ``self.amp_args`` is set and ``self.amp_type`` is ``AmpType.APEX``,
       initialize ``apex.amp`` on the model (and on the wrapped PyTorch
       optimizer, if there is one).
    6. Resolve ``self.simulated_global_batchsize`` (defaulting to the real
       global batch size) and derive ``self.optimizer_period``: the number
       of global batches accumulated per optimizer step.
    7. If ``self.checkpoint_path`` is set, load ``self.checkpoint_dict`` from
       it; then restore task state from the dict's ``"classy_state_dict"``
       entry when a checkpoint dict is present.
    8. Call ``self.init_distributed_data_parallel_model()``.

    Raises:
        ValueError: if ``simulated_global_batchsize`` is set but is not a
            positive multiple of the global batch size.
        AssertionError: if restoring task state from the checkpoint
            reports failure.
    """
    self.phases = self._build_phases()
    self.train = False if self.test_only else self.train

    if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
        self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
    elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
        sync_bn_process_group = apex.parallel.create_syncbn_process_group(
            self.batch_norm_sync_group_size
        )
        self.base_model = apex.parallel.convert_syncbn_model(
            self.base_model, process_group=sync_bn_process_group
        )

    # Move the model and loss to the right device.
    if self.use_gpu:
        self.base_model, self.base_loss = copy_model_to_gpu(
            self.base_model, self.base_loss
        )
    else:
        self.base_loss.cpu()
        self.base_model.cpu()

    if self.optimizer is not None:
        self.prepare_optimizer(
            optimizer=self.optimizer, model=self.base_model, loss=self.base_loss
        )

    if self.amp_args is not None and self.amp_type == AmpType.APEX:
        # Initialize apex.amp. This updates the model and the PyTorch optimizer
        # (if training, which is wrapped by the ClassyOptimizer in
        # self.optimizer). This must happen before loading the checkpoint,
        # because there is amp state to be restored.
        if self.optimizer is None:
            self.base_model = apex.amp.initialize(
                self.base_model, optimizers=None, **self.amp_args
            )
        else:
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

    # Gradient accumulation: one optimizer step every `optimizer_period`
    # global batches, so the effective batch is simulated_global_batchsize.
    global_batchsize = self.get_global_batchsize()
    if self.simulated_global_batchsize is not None:
        # A non-positive value would yield optimizer_period <= 0, which is
        # not a meaningful accumulation period.
        if self.simulated_global_batchsize <= 0:
            raise ValueError(
                f"simulated_global_batchsize ({self.simulated_global_batchsize}) "
                "must be positive"
            )
        if self.simulated_global_batchsize % global_batchsize != 0:
            raise ValueError(
                f"Global batch size ({global_batchsize}) must divide "
                f"simulated_global_batchsize ({self.simulated_global_batchsize})"
            )
    else:
        self.simulated_global_batchsize = global_batchsize

    self.optimizer_period = self.simulated_global_batchsize // global_batchsize
    if self.optimizer_period > 1:
        logging.info(
            "Using gradient accumulation with a period of %s", self.optimizer_period
        )

    if self.checkpoint_path:
        self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)

    # checkpoint_dict may be None (nothing loaded / provided); only restore
    # state when a checkpoint is actually present.
    classy_state_dict = (
        None
        if self.checkpoint_dict is None
        else self.checkpoint_dict["classy_state_dict"]
    )
    if classy_state_dict is not None:
        state_load_success = update_classy_state(self, classy_state_dict)
        if not state_load_success:
            # Explicit raise instead of `assert` so this check is not stripped
            # under `python -O`; keeps the AssertionError type callers saw.
            raise AssertionError(
                "Update classy state from checkpoint was unsuccessful."
            )

    self.init_distributed_data_parallel_model()