def set_classy_state()

in classy_vision/tasks/classification_task.py


    def set_classy_state(self, state):
        """Set task state

        Args:
            state: Dict containing state of a task
        """
        self.train = False if self.test_only else state["train"]
        self.base_model.set_classy_state(state["base_model"])

        if self.test_only:
            # if we're only testing, just need the state of the model to be updated
            return

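        # Restore training progress: phase counters, update count,
        # loss history and meter state.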
        self.phase_idx = state["phase_idx"]
        self.num_updates = state["num_updates"]
        self.train_phase_idx = state["train_phase_idx"]
        self.losses = state["losses"]
        for meter, meter_state in zip(self.meters, state["meters"]):
            meter.set_classy_state(meter_state)

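        # Restore optimizer state and, if the loss is stateful (a ClassyLoss),
        # the loss state.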
        if self.optimizer is not None:
            self.optimizer.set_classy_state(state["optimizer"])
        if state.get("loss") and isinstance(self.base_loss, ClassyLoss):
            self.base_loss.set_classy_state(state["loss"])

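        # Restore mixed precision state: apex AMP or the native grad scaler.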
        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                apex.amp.load_state_dict(state["amp"])
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])

        for hook in self.hooks:
            # Keep running even if hooks were added or removed since the
            # checkpoint was saved; only restore state for hooks we find.
            if hook.name() in state["hooks"]:
                hook.set_classy_state(state["hooks"][hook.name()])
            else:
                logging.warning(f"No state found for hook: {hook.name()}")

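        # Restore the train dataset iterator position when the dataset
        # supports checkpointing.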
        if "train" in self.datasets and self._is_checkpointable_dataset(
            self.datasets["train"]
        ):
            self.datasets["train"].set_classy_state(state.get("train_dataset_iterator"))
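
A minimal round-trip sketch follows. It is only illustrative: it assumes a fully built ClassificationTask named task (model, optimizer, meters and hooks already attached) and that the state dict comes from a matching get_classy_state() call on an identically configured task.

    # Sketch only: `task` is assumed to be a fully configured
    # ClassificationTask; the state layout mirrors the keys read above.
    state = task.get_classy_state()

    # ... later, e.g. after a restart, with an identically configured task ...
    task.set_classy_state(state)  # restores model weights; unless test_only,
                                  # also phase counters, losses, meters,
                                  # optimizer, loss, AMP state, hooks and the
                                  # train dataset iterator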