in classy_vision/tasks/classification_task.py
def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
    """Disable / enable apex.amp and set the automatic mixed precision parameters.

    apex.amp can be utilized for mixed / half precision training.

    Args:
        amp_args: Dictionary containing arguments to be passed to
            amp.initialize. Set to None to disable amp. To enable mixed
            precision training, pass amp_args={"opt_level": "O1"} here.

            See https://nvidia.github.io/apex/amp.html for more info.

    Raises:
        RuntimeError: If amp_args is not None and CUDA is unavailable, if
            Apex AMP is requested but apex is not installed, or if ShardedDDP
            is requested together with Apex AMP or without fairscale.

    Warning: apex needs to be installed to utilize Apex AMP.
    """
    self.amp_args = amp_args

    if amp_args is None:
        logging.info("AMP disabled")
    else:
        # Check that the requested AMP type is known
        try:
            self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
        except KeyError:
            logging.info("AMP type not specified, defaulting to Apex")
            self.amp_type = AmpType.APEX

        # Check for CUDA availability, required for both Apex and Pytorch AMP
        if not torch.cuda.is_available():
            raise RuntimeError(
                "AMP is required but CUDA is not supported, cannot enable AMP"
            )

        # Check for Apex availability
        if self.amp_type == AmpType.APEX and not apex_available:
            raise RuntimeError(
                "Apex AMP is required but Apex is not installed, cannot enable AMP"
            )

        if self.use_sharded_ddp:
            if self.amp_type == AmpType.APEX:
                raise RuntimeError(
                    "ShardedDDP has been requested, which is incompatible with Apex AMP"
                )
            if not fairscale_available:
                raise RuntimeError(
                    "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                )
        # Set Torch AMP grad scaler, used to prevent gradient underflow
        if self.amp_type == AmpType.PYTORCH:
            if self.use_sharded_ddp:
                logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                self.amp_grad_scaler = ShardedGradScaler()
            else:
                self.amp_grad_scaler = TorchGradScaler()

        logging.info(f"AMP enabled with args {amp_args}")
    return self
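
A minimal usage sketch, assuming a ClassificationTask instance named `task` has been built elsewhere; the dictionaries follow the keys this method actually reads ("amp_type", "opt_level") and are illustrative rather than exhaustive:

# Native PyTorch AMP: set_amp_args creates a TorchGradScaler (or a
# ShardedGradScaler when ShardedDDP is in use) to guard against gradient
# underflow during reduced-precision training.
task.set_amp_args({"amp_type": "pytorch"})

# Apex AMP at optimization level O1; requires apex and a visible CUDA
# device, per the checks above.
task.set_amp_args({"opt_level": "O1"})

# Disable mixed precision entirely.
task.set_amp_args(None)

Because the method returns self, these calls can be chained with the task's other setters when assembling a task from a config.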