def prepare()

in classy_vision/tasks/fine_tuning_task.py


    def prepare(self) -> None:
        super().prepare()
        if self.checkpoint_dict is None:
            # no checkpoint exists, load the model's state from the pretrained
            # checkpoint

            if self.pretrained_checkpoint_path:
                self.pretrained_checkpoint_dict = load_and_broadcast_checkpoint(
                    self.pretrained_checkpoint_path
                )

            assert (
                self.pretrained_checkpoint_dict is not None
            ), "Need a pretrained checkpoint for fine tuning"

            state_load_success = update_classy_model(
                self.base_model,
                self.pretrained_checkpoint_dict["classy_state_dict"]["base_model"],
                self.reset_heads,
                self.pretrained_checkpoint_load_strict,
            )
            assert (
                state_load_success
            ), "Update classy state from pretrained checkpoint was unsuccessful."

        if self.freeze_trunk:
            # freeze all parameters of the model, then re-enable gradient
            # tracking only for the parameters in the heads
            for param in self.base_model.parameters():
                param.requires_grad = False
            for heads in self.base_model.get_heads().values():
                for h in heads:
                    for param in h.parameters():
                        param.requires_grad = True
            # re-create the DDP model so the wrapper reflects the updated
            # requires_grad settings
            self.distributed_model = None
            self.init_distributed_data_parallel_model()
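
The checkpoint-loading branch delegates to the load_and_broadcast_checkpoint and
update_classy_model helpers, but the underlying idea is the ordinary PyTorch pattern of
restoring a pretrained state dict into a model, with a strict flag deciding whether every
key must match. A minimal, self-contained sketch of that pattern (the build_model helper,
the model layout, and the checkpoint file name are hypothetical, for illustration only):

    import torch
    import torch.nn as nn

    # Hypothetical trunk + head model, standing in for a pretrained base model.
    def build_model() -> nn.Module:
        return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Pretend a pretrained checkpoint already exists on disk.
    pretrained = build_model()
    torch.save(pretrained.state_dict(), "pretrained_checkpoint.pt")

    # Fine-tuning side: restore the pretrained weights into a fresh model.
    model = build_model()
    state_dict = torch.load("pretrained_checkpoint.pt", map_location="cpu")

    # strict=False plays the role of a non-strict load here (compare
    # pretrained_checkpoint_load_strict above): missing or unexpected keys,
    # e.g. a freshly re-initialized head, are tolerated instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("missing keys:", missing, "unexpected keys:", unexpected)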
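
The freeze_trunk branch follows the usual fine-tuning recipe: disable gradients for the
whole model, re-enable them only on the heads, and rebuild the DDP wrapper so it tracks
just the trainable parameters. A minimal sketch of the same pattern in plain PyTorch
(the trunk and head modules are hypothetical, and the DDP step is only indicated by a
comment):

    import torch.nn as nn

    # Hypothetical trunk + head, standing in for base_model and its attached heads.
    trunk = nn.Sequential(
        nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
    )
    head = nn.Linear(8, 10)
    model = nn.Sequential(trunk, head)

    # Freeze everything first...
    for param in model.parameters():
        param.requires_grad = False

    # ...then re-enable gradients only for the head parameters.
    for param in head.parameters():
        param.requires_grad = True

    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    print("trainable parameters:", trainable)

    # In a distributed run, the DDP wrapper would be re-created at this point so
    # that it only registers hooks for parameters that still require gradients.

Only the head's weight and bias remain trainable, which mirrors what the loop over
self.base_model.get_heads() achieves for heads attached to a Classy Vision model.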