in tutorials/e2e-distributed-pytorch-image/src/pytorch_dl_train/train.py

def setup_config(self, args):
    """Sets internal variables using provided CLI arguments (see build_arguments_parser()).
    In particular, sets the torch device (cuda or cpu) and multinode parameters."""
    self.dataloading_config = args
    self.training_config = args

    # verify parameter default values
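    # NOTE: the checks below implement the following convention:
    #   num_workers is None -> 0 (data loading in the main process)
    #   num_workers < 0     -> use all CPUs (os.cpu_count())
    #   num_workers == 0    -> prefetch_factor is discarded, since DataLoader only
    #                          accepts prefetch_factor when worker processes are used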
    if self.dataloading_config.num_workers is None:
        self.dataloading_config.num_workers = 0
    if self.dataloading_config.num_workers < 0:
        self.dataloading_config.num_workers = os.cpu_count()
    if self.dataloading_config.num_workers == 0:
        self.logger.warning(
            "You specified num_workers=0, forcing prefetch_factor to be discarded."
        )
        self.dataloading_config.prefetch_factor = None

    # NOTE: strtobool returns an int, converting to bool explicitly
    self.dataloading_config.pin_memory = bool(self.dataloading_config.pin_memory)
    self.dataloading_config.non_blocking = bool(
        self.dataloading_config.non_blocking
    )

    # DISTRIBUTED: detect multinode config
    # depending on the Azure ML distribution.type, different environment variables will be provided
    # to configure DistributedDataParallel
    self.distributed_backend = args.distributed_backend
    if self.distributed_backend == "nccl":
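        # WORLD_SIZE, RANK, LOCAL_WORLD_SIZE and LOCAL_RANK are the variables exported
        # by torchrun (the torch.distributed elastic launcher); the Azure ML
        # `distribution.type: pytorch` setting is expected to provide the same
        # variables to each process, with "1"/"0" used as safe fallbacks below.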
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.world_rank = int(os.environ.get("RANK", "0"))
        self.local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.multinode_available = self.world_size > 1
        self.self_is_main_node = self.world_rank == 0

    elif self.distributed_backend == "mpi":
        # Note: the distributed PyTorch package doesn't have MPI built in.
        # MPI is only included if you build PyTorch from source on a host that has MPI installed.
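        # OMPI_COMM_WORLD_* variables are set by the Open MPI launcher (mpirun),
        # which is what the Azure ML `distribution.type: mpi` setting relies on.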
        self.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
        self.world_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
        self.local_world_size = int(
            os.environ.get("OMPI_COMM_WORLD_LOCAL_SIZE", "1")
        )
        self.local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
        self.multinode_available = self.world_size > 1
        self.self_is_main_node = self.world_rank == 0

    else:
        raise NotImplementedError(
            f"distributed_backend={self.distributed_backend} is not implemented yet."
        )
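    # NOTE: the branch taken mirrors the launcher configured in the Azure ML job
    # (`distribution.type: pytorch` -> torchrun-style variables, handled as "nccl";
    # `distribution.type: mpi` -> OMPI_* variables). This mapping is inferred from
    # the environment variables read above.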

    # Use CUDA if it is available
    if torch.cuda.is_available():
        self.logger.info(
            f"Setting up torch.device for CUDA for local gpu:{self.local_rank}"
        )
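        # passing an integer to torch.device selects cuda:<local_rank>,
        # so each process binds to its own local GPU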
        self.device = torch.device(self.local_rank)
    else:
        self.logger.info("Setting up torch.device for cpu")
        self.device = torch.device("cpu")

    if self.multinode_available:
        self.logger.info(
            f"Running in multinode with backend={self.distributed_backend} local_rank={self.local_rank} rank={self.world_rank} size={self.world_size}"
        )
        # DISTRIBUTED: this is required to initialize the pytorch backend
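        # With the default env:// rendezvous (used by the nccl backend here),
        # init_process_group also expects MASTER_ADDR and MASTER_PORT in the
        # environment; those are set by the launcher, not by this script. The mpi
        # backend takes its rendezvous from the MPI runtime instead.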
        torch.distributed.init_process_group(
            self.distributed_backend,
            rank=self.world_rank,
            world_size=self.world_size,
        )
    else:
        self.logger.info("Not running in multinode.")

    # DISTRIBUTED: in distributed mode, you want to report parameters
    # only from the main process (rank==0) to avoid conflicts
    if self.self_is_main_node:
        # MLFLOW: report relevant parameters using mlflow
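        # Azure ML jobs come pre-configured with an MLflow tracking URI, so these
        # parameters land on the job's run without any explicit mlflow setup here
        # (assuming the script runs inside an Azure ML job with azureml-mlflow installed).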
        mlflow.log_params(
            {
                # log some distribution params
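                # nodes is derived as world_size // local_world_size, which assumes
                # the same number of processes on every node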
"nodes": self.world_size // self.local_world_size,
"instance_per_node": self.local_world_size,
"cuda_available": torch.cuda.is_available(),
"cuda_device_count": torch.cuda.device_count(),
"distributed": self.multinode_available,
"distributed_backend": self.distributed_backend,
# data loading params
"batch_size": self.dataloading_config.batch_size,
"num_workers": self.dataloading_config.num_workers,
"prefetch_factor": self.dataloading_config.prefetch_factor,
"pin_memory": self.dataloading_config.pin_memory,
"non_blocking": self.dataloading_config.non_blocking,
# training params
"model_arch": self.training_config.model_arch,
"model_arch_pretrained": self.training_config.model_arch_pretrained,
"learning_rate": self.training_config.learning_rate,
"num_epochs": self.training_config.num_epochs,
# profiling params
"enable_profiling": self.training_config.enable_profiling,
}
)
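
# Usage sketch (not part of the original file, names partly assumed): setup_config()
# is meant to be called once per process, right after CLI parsing and before building
# data loaders or the model. build_arguments_parser() is referenced in the docstring;
# the owning class name below is hypothetical.
#
#   args = build_arguments_parser().parse_args()
#   sequence = TrainingSequence()   # hypothetical class exposing setup_config()
#   sequence.setup_config(args)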