def setup_config()

in tutorials/e2e-distributed-pytorch-image/src/pytorch_dl_train/train.py


    def setup_config(self, args):
        """Sets internal variables using provided CLI arguments (see build_arguments_parser()).
        In particular, sets device(cuda) and multinode parameters."""
        # both names alias the same argparse namespace; the split is purely semantic
        self.dataloading_config = args
        self.training_config = args

        # normalize num_workers: None means 0, negative means "use all CPU cores"
        if self.dataloading_config.num_workers is None:
            self.dataloading_config.num_workers = 0
        if self.dataloading_config.num_workers < 0:
            self.dataloading_config.num_workers = os.cpu_count()
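        # PyTorch's DataLoader only accepts prefetch_factor when num_workers > 0,
        # so it must be cleared when multiprocess loading is disabled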
        if self.dataloading_config.num_workers == 0:
            self.logger.warning(
                "You specified num_workers=0, forcing prefetch_factor to be discarded."
            )
            self.dataloading_config.prefetch_factor = None

        # NOTE: strtobool returns an int, converting to bool explicitly
        self.dataloading_config.pin_memory = bool(self.dataloading_config.pin_memory)
        self.dataloading_config.non_blocking = bool(
            self.dataloading_config.non_blocking
        )

        # DISTRIBUTED: detect multinode config
        # depending on the Azure ML distribution.type, different environment variables will be provided
        # to configure DistributedDataParallel
        self.distributed_backend = args.distributed_backend
        if self.distributed_backend == "nccl":
            self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
            self.world_rank = int(os.environ.get("RANK", "0"))
            self.local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
            self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
            self.multinode_available = self.world_size > 1
            self.self_is_main_node = self.world_rank == 0

        elif self.distributed_backend == "mpi":
            # Note: PyTorch's distributed package doesn't include MPI by default.
            # MPI is only included if you build PyTorch from source on a host that has MPI installed.
            self.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
            self.world_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
            self.local_world_size = int(
                os.environ.get("OMPI_COMM_WORLD_LOCAL_SIZE", "1")
            )
            self.local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
            self.multinode_available = self.world_size > 1
            self.self_is_main_node = self.world_rank == 0

        else:
            raise NotImplementedError(
                f"distributed_backend={self.distributed_backend} is not implemented yet."
            )

        # Use CUDA if it is available
        if torch.cuda.is_available():
            self.logger.info(
                f"Setting up torch.device for CUDA for local gpu:{self.local_rank}"
            )
            self.device = torch.device(self.local_rank)  # an int is interpreted as a CUDA device index
        else:
            self.logger.info("Setting up torch.device for cpu")
            self.device = torch.device("cpu")

        if self.multinode_available:
            self.logger.info(
                f"Running in multinode with backend={self.distributed_backend} local_rank={self.local_rank} rank={self.world_rank} size={self.world_size}"
            )
            # DISTRIBUTED: this is required to initialize the pytorch backend
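            # with the default init_method ("env://"), torch.distributed also expects
            # MASTER_ADDR and MASTER_PORT in the environment (set by the job launcher)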
            torch.distributed.init_process_group(
                self.distributed_backend,
                rank=self.world_rank,
                world_size=self.world_size,
            )
        else:
            self.logger.info(f"Not running in multinode.")

        # DISTRIBUTED: in distributed mode, you want to report parameters
        # only from the main process (rank==0) to avoid conflicts
        if self.self_is_main_node:
            # MLFLOW: report relevant parameters using mlflow
            mlflow.log_params(
                {
                    # log some distribution params
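                    # e.g. world_size=8 with local_world_size=4 -> nodes=2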
                    "nodes": self.world_size // self.local_world_size,
                    "instance_per_node": self.local_world_size,
                    "cuda_available": torch.cuda.is_available(),
                    "cuda_device_count": torch.cuda.device_count(),
                    "distributed": self.multinode_available,
                    "distributed_backend": self.distributed_backend,
                    # data loading params
                    "batch_size": self.dataloading_config.batch_size,
                    "num_workers": self.dataloading_config.num_workers,
                    "prefetch_factor": self.dataloading_config.prefetch_factor,
                    "pin_memory": self.dataloading_config.pin_memory,
                    "non_blocking": self.dataloading_config.non_blocking,
                    # training params
                    "model_arch": self.training_config.model_arch,
                    "model_arch_pretrained": self.training_config.model_arch_pretrained,
                    "learning_rate": self.training_config.learning_rate,
                    "num_epochs": self.training_config.num_epochs,
                    # profiling params
                    "enable_profiling": self.training_config.enable_profiling,
                }
            )
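
For context, here is a minimal driver sketch showing how setup_config() is typically fed. The parser below is an illustrative stand-in: the real build_arguments_parser() lives in the same file and may use different names and defaults. It uses distutils.util.strtobool to match the NOTE in the code above (deprecated in recent Python versions).

    import argparse
    from distutils.util import strtobool

    def build_arguments_parser_sketch():
        """Hypothetical stand-in for build_arguments_parser(); the argument
        names mirror the fields setup_config() reads, defaults are illustrative."""
        parser = argparse.ArgumentParser()
        # data loading parameters (read through self.dataloading_config)
        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--num_workers", type=int, default=None)
        parser.add_argument("--prefetch_factor", type=int, default=2)
        parser.add_argument("--pin_memory", type=strtobool, default=True)
        parser.add_argument("--non_blocking", type=strtobool, default=False)
        # distributed setup
        parser.add_argument("--distributed_backend", type=str, default="nccl")
        # training parameters (read through self.training_config)
        parser.add_argument("--model_arch", type=str, default="resnet18")
        parser.add_argument("--model_arch_pretrained", type=strtobool, default=True)
        parser.add_argument("--learning_rate", type=float, default=0.001)
        parser.add_argument("--num_epochs", type=int, default=1)
        parser.add_argument("--enable_profiling", type=strtobool, default=False)
        return parser

    # args = build_arguments_parser_sketch().parse_args()
    # trainer.setup_config(args)  # trainer: instance of the class defining setup_config()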