def prepare()

in classy_vision/tasks/classification_task.py [0:0]


    def prepare(self):
        """Prepare the task for execution, populating all derived attributes.

        Performs, in order: phase construction, sync-BN conversion, device
        placement, optimizer preparation, apex.amp initialization, gradient
        accumulation setup, checkpoint loading/restoration, and DDP wrapping.
        """

        self.phases = self._build_phases()
        if self.test_only:
            self.train = False

        # Swap BatchNorm layers for their synchronized variants if requested.
        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size
            )
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=process_group
            )

        # Place both the model and the loss on the target device.
        if not self.use_gpu:
            self.base_loss.cpu()
            self.base_model.cpu()
        else:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss
            )

        if self.optimizer is not None:
            self.prepare_optimizer(
                optimizer=self.optimizer, model=self.base_model, loss=self.base_loss
            )

        if self.amp_args is not None and self.amp_type == AmpType.APEX:
            # Initialize apex.amp. This updates the model and the PyTorch
            # optimizer (if training, which is wrapped by the ClassyOptimizer
            # in self.optimizer). Note this must happen before loading the
            # checkpoint, since there is amp state to be restored.
            if self.optimizer is not None:
                self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                    self.base_model, self.optimizer.optimizer, **self.amp_args
                )
            else:
                self.base_model = apex.amp.initialize(
                    self.base_model, optimizers=None, **self.amp_args
                )

        # Derive the gradient-accumulation period from the simulated batch size;
        # the real global batch size must evenly divide the simulated one.
        global_batchsize = self.get_global_batchsize()
        if self.simulated_global_batchsize is None:
            self.simulated_global_batchsize = global_batchsize
        elif self.simulated_global_batchsize % global_batchsize != 0:
            raise ValueError(
                f"Global batch size ({self.get_global_batchsize()}) must divide "
                f"simulated_global_batchsize ({self.simulated_global_batchsize})"
            )

        self.optimizer_period = self.simulated_global_batchsize // global_batchsize
        if self.optimizer_period > 1:
            logging.info(
                f"Using gradient accumulation with a period of {self.optimizer_period}"
            )

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)

        # Restore task state when a checkpoint was loaded.
        classy_state_dict = None
        if self.checkpoint_dict is not None:
            classy_state_dict = self.checkpoint_dict["classy_state_dict"]

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (
                state_load_success
            ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()