in classy_vision/tasks/classification_task.py [0:0]
def get_classy_state(self, deep_copy: bool = False):
    """Collect the task's serializable state into a dictionary.

    The result bundles the state of the base model, optimizer, meters,
    hooks and loss together with the phase / update counters and the
    recorded losses. The train dataset iterator state and the AMP state
    are added only when they are available.

    Args:
        deep_copy: If true, return a deep copy of the collected state
            instead of the dictionary that references live task objects.

    Returns:
        A dictionary holding the task state.
    """
    # An absent optimizer is represented by an empty state dict.
    optimizer_snapshot = (
        self.optimizer.get_classy_state() if self.optimizer is not None else {}
    )

    # Built key by key so both the key order and the order in which the
    # sub-components are queried stay fixed.
    state = {}
    state["train"] = self.train
    state["base_model"] = self.base_model.get_classy_state()
    state["meters"] = [m.get_classy_state() for m in self.meters]
    state["optimizer"] = optimizer_snapshot
    state["phase_idx"] = self.phase_idx
    state["train_phase_idx"] = self.train_phase_idx
    state["num_updates"] = self.num_updates
    state["losses"] = self.losses
    state["hooks"] = {h.name(): h.get_classy_state() for h in self.hooks}
    # Placeholder; overwritten below when the loss is a ClassyLoss.
    state["loss"] = {}

    # Only a checkpointable train dataset contributes iterator state.
    if "train" in self.datasets:
        train_dataset = self.datasets["train"]
        if self._is_checkpointable_dataset(train_dataset):
            state["train_dataset_iterator"] = train_dataset.get_classy_state()

    if isinstance(self.base_loss, ClassyLoss):
        state["loss"] = self.base_loss.get_classy_state()

    # AMP state comes from apex when that backend is selected, otherwise
    # from the grad scaler if one exists.
    amp_configured = self.amp_args is not None
    if amp_configured and self.amp_type == AmpType.APEX:
        state["amp"] = apex.amp.state_dict()
    elif amp_configured and self.amp_grad_scaler is not None:
        state["amp"] = self.amp_grad_scaler.state_dict()

    return copy.deepcopy(state) if deep_copy else state