in classy_vision/tasks/fine_tuning_task.py [0:0]
def prepare(self) -> None:
    super().prepare()
    if self.checkpoint_dict is None:
        # no checkpoint exists, load the model's state from the pretrained
        # checkpoint
        if self.pretrained_checkpoint_path:
            self.pretrained_checkpoint_dict = load_and_broadcast_checkpoint(
                self.pretrained_checkpoint_path
            )

        assert (
            self.pretrained_checkpoint_dict is not None
        ), "Need a pretrained checkpoint for fine tuning"

        state_load_success = update_classy_model(
            self.base_model,
            self.pretrained_checkpoint_dict["classy_state_dict"]["base_model"],
            self.reset_heads,
            self.pretrained_checkpoint_load_strict,
        )
        assert (
            state_load_success
        ), "Update classy state from pretrained checkpoint was unsuccessful."

    if self.freeze_trunk:
        # do not track gradients for all the parameters in the model except
        # for the parameters in the heads
        for param in self.base_model.parameters():
            param.requires_grad = False
        for heads in self.base_model.get_heads().values():
            for h in heads:
                for param in h.parameters():
                    param.requires_grad = True
        # re-create ddp model
        self.distributed_model = None
        self.init_distributed_data_parallel_model()
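
The freeze_trunk branch follows the usual fine-tuning recipe: turn gradients off for every parameter, then turn them back on only for the head parameters, and rebuild the distributed wrapper, since DistributedDataParallel decides which parameters to synchronize at construction time and therefore has to be re-created after requires_grad changes. A minimal sketch of the same pattern in plain PyTorch (hypothetical trunk/head modules, not the Classy Vision API):

# Minimal sketch, assuming a hypothetical trunk/head split in plain PyTorch.
import torch.nn as nn

trunk = nn.Sequential(nn.Linear(16, 32), nn.ReLU())  # stands in for the pretrained trunk
head = nn.Linear(32, 10)                              # stands in for a freshly attached head
model = nn.Sequential(trunk, head)

# Freeze everything, then re-enable gradients only on the head.
for param in model.parameters():
    param.requires_grad = False
for param in head.parameters():
    param.requires_grad = True

# Only the head's parameters remain trainable.
assert all(not p.requires_grad for p in trunk.parameters())
assert all(p.requires_grad for p in head.parameters())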