def set_amp_args()

in classy_vision/tasks/classification_task.py [0:0]


    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable automatic mixed precision (AMP) and store its settings.

        AMP can be utilized for mixed / half precision training. Two backends
        are handled here: Apex AMP and PyTorch AMP.

        Args:
            amp_args: Dictionary of AMP settings, stored on ``self.amp_args``.
                Set to None to disable AMP. The optional ``"amp_type"`` key
                selects the backend by name (case-insensitive member of
                ``AmpType``, e.g. ``"apex"`` or ``"pytorch"``); when it is
                missing or does not name a known backend, Apex is used.
                For Apex, e.g. amp_args={"opt_level": "O1"} enables mixed
                precision training.
                See https://nvidia.github.io/apex/amp.html for more info.

        Returns:
            self, so that setter calls can be chained.

        Raises:
            RuntimeError: If AMP is requested and CUDA is not available, if
                Apex AMP is selected but Apex is not installed, if ShardedDDP
                is combined with Apex AMP, or if ShardedDDP is requested but
                fairscale is not installed.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
            return self

        # Resolve the requested AMP backend. A missing key and an unknown
        # value both fall back to Apex, but they are reported differently so
        # that a typo in the config does not go unnoticed.
        if "amp_type" not in amp_args:
            logging.info("AMP type not specified, defaulting to Apex")
            self.amp_type = AmpType.APEX
        else:
            requested_type = amp_args["amp_type"]
            try:
                self.amp_type = AmpType[requested_type.upper()]
            except KeyError:
                logging.warning(
                    f"Unknown AMP type {requested_type!r}, defaulting to Apex"
                )
                self.amp_type = AmpType.APEX

        # Check for CUDA availability, required for both Apex and Pytorch AMP
        if not torch.cuda.is_available():
            raise RuntimeError(
                "AMP is required but CUDA is not supported, cannot enable AMP"
            )

        # Check for Apex availability
        if self.amp_type == AmpType.APEX and not apex_available:
            raise RuntimeError(
                "Apex AMP is required but Apex is not installed, cannot enable AMP"
            )

        # Validate the ShardedDDP / AMP combination
        if self.use_sharded_ddp:
            if self.amp_type == AmpType.APEX:
                raise RuntimeError(
                    "ShardedDDP has been requested, which is incompatible with Apex AMP"
                )

            if not fairscale_available:
                raise RuntimeError(
                    "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                )

        # Set Torch AMP grad scaler, used to prevent gradient underflow.
        # This must be an independent check (not an ``elif`` of the ShardedDDP
        # validation above), otherwise the sharded scaler can never be chosen.
        if self.amp_type == AmpType.PYTORCH:
            if self.use_sharded_ddp:
                logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                self.amp_grad_scaler = ShardedGradScaler()
            else:
                self.amp_grad_scaler = TorchGradScaler()

        logging.info(f"AMP enabled with args {amp_args}")
        return self