in classy_vision/tasks/classification_task.py [0:0]
def set_classy_state(self, state):
    """Restore the task from a state dict.

    In test-only mode only the model is restored and ``self.train`` is
    forced to False. Phase counters, losses, meters, optimizer, loss, AMP,
    hook and dataset state are left untouched.

    Args:
        state: Dict containing the state of a task. Keys read here:
            ``base_model`` always.
            ``train``, ``phase_idx``, ``num_updates``, ``train_phase_idx``,
            ``losses`` and ``meters`` when not in test-only mode.
            ``optimizer`` when the task has an optimizer.
            ``hooks`` when the task has hooks.
            ``loss``, ``amp`` and ``train_dataset_iterator`` optionally.
    """
    self.train = False if self.test_only else state["train"]
    self.base_model.set_classy_state(state["base_model"])
    if self.test_only:
        # If we're only testing, just the model state needs to be updated.
        return

    self.phase_idx = state["phase_idx"]
    self.num_updates = state["num_updates"]
    self.train_phase_idx = state["train_phase_idx"]
    self.losses = state["losses"]
    # NOTE(review): zip() stops at the shorter sequence, so a mismatch
    # between the task's meters and the saved meter states is not reported.
    for meter, meter_state in zip(self.meters, state["meters"]):
        meter.set_classy_state(meter_state)

    if self.optimizer is not None:
        self.optimizer.set_classy_state(state["optimizer"])

    # Truthiness check on purpose: a missing or empty loss state is skipped.
    if state.get("loss") and isinstance(self.base_loss, ClassyLoss):
        self.base_loss.set_classy_state(state["loss"])

    if "amp" in state:
        if self.amp_type is None:
            # AMP is disabled for this task, so there is no AMP backend to
            # restore into. Previously this fell through to
            # ``self.amp_grad_scaler.load_state_dict`` and failed; skip the
            # saved AMP state with a warning instead.
            logging.warning(
                "Checkpoint contains AMP state but AMP is disabled for this "
                "task; ignoring it"
            )
        elif self.amp_type == AmpType.APEX:
            apex.amp.load_state_dict(state["amp"])
        else:
            self.amp_grad_scaler.load_state_dict(state["amp"])

    for hook in self.hooks:
        # We still want to be able to run when new hooks are added or old
        # hooks are removed: restore matching hooks and warn about the rest.
        hook_name = hook.name()
        if hook_name in state["hooks"]:
            hook.set_classy_state(state["hooks"][hook_name])
        else:
            logging.warning(f"No state found for hook: {hook_name}")

    if "train" in self.datasets and self._is_checkpointable_dataset(
        self.datasets["train"]
    ):
        # NOTE(review): passes None when the checkpoint has no iterator
        # state; the dataset's set_classy_state must accept that -- confirm.
        self.datasets["train"].set_classy_state(state.get("train_dataset_iterator"))