def _setup_devices()

in optimum/habana/transformers/training_args.py [0:0]


    def _setup_devices(self) -> "torch.device":
        """Apply device-related global settings, create the distributed state and return the device.

        Steps, in order:

        1. Optionally declare the bf16/fp32 autocast op lists from the Gaudi configuration.
        2. Forward a few options to `torch._C` / `torch._dynamo.config`.
        3. Either reuse an already configured `PartialState` (when the accelerator config asks for
           `use_configured_state`) or reset the accelerate state and build a new `PartialState`.
        4. Initialize (or validate) the sequence-parallel state when running in distributed mode.

        Side effects visible in this method: sets `self.distributed_state`, `self.local_rank` and
        `self._n_gpu`; may set the `ACCELERATE_USE_IPEX` and `PT_HPU_FP8_MINIMIZE_MEMORY` environment
        variables; temporarily sets `ACCELERATE_USE_DEEPSPEED` while `PartialState` is being built.

        Returns:
            The `device` attribute of the resulting `self.distributed_state`.

        Raises:
            ImportError: if `accelerate` is not available.
            ValueError: if `use_configured_state` is requested but no state was configured beforehand,
                if neither `use_cpu`/`ACCELERATE_USE_CPU` nor `use_habana` selects a device, if neither
                lazy mode nor compile mode is enabled while `PT_HPU_LAZY_MODE` is not `"0"`, or if the
                already initialized sequence parallel world size differs from `context_parallel_size`.
            RuntimeError: if a pre-configured state is reused with a `deepspeed` configuration but was
                not initialized for DeepSpeed.
        """
        requires_backends(self, ["torch"])

        # Hack to make sure bf16/fp32 ops are specified before calling habana_frameworks.torch.core
        if self.gaudi_config_name is not None:
            gaudi_config = GaudiConfig.from_pretrained(self.gaudi_config_name)
            if (
                (self.bf16 or gaudi_config.use_torch_autocast)
                and not self.deepspeed
                and self.half_precision_backend == "hpu_amp"
            ):
                gaudi_config.declare_autocast_bf16_fp32_ops()

        if self.sdp_on_bf16:
            torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)

        # The dynamo options below are only overridden when the corresponding argument is set.
        if self.inline_inbuilt_nn_modules is not None:
            torch._dynamo.config.inline_inbuilt_nn_modules = self.inline_inbuilt_nn_modules

        if self.torch_compile and self.cache_size_limit is not None:
            torch._dynamo.config.cache_size_limit = self.cache_size_limit

        if self.allow_unspec_int_on_nn_module is not None:
            torch._dynamo.config.allow_unspec_int_on_nn_module = self.allow_unspec_int_on_nn_module

        logger.info("PyTorch: setting up devices")
        if not is_accelerate_available():
            raise ImportError(
                f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
                f"Please run `pip install transformers[torch]` or `pip install accelerate -U`"
            )
        # We delay the init of `PartialState` to the end for clarity
        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
        if isinstance(self.accelerator_config, AcceleratorConfig):
            # NOTE: `pop` removes the key from the accelerator config, so it is consumed here.
            accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
                "use_configured_state", False
            )
        if accelerator_state_kwargs["use_configured_state"]:
            if PartialState._shared_state == {}:
                raise ValueError(
                    "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
                    "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
                )
            # We rely on `PartialState` to yell if there's issues here (which it will)
            self.distributed_state = PartialState(cpu=self.use_cpu)
            if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
                raise RuntimeError(
                    "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
                    "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
                    "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
                )
        else:
            # Start from a clean accelerate state; the new `PartialState` is created further down.
            AcceleratorState._reset_state(reset_partial_state=True)
            self.distributed_state = None

        # Set the log level here for optimum.utils.logging
        # otherwise logs are not sent in this method.
        log_level = self.get_process_log_level()
        logging.set_verbosity(log_level)

        # Only write the IPEX switch when the user has not already provided one in the environment.
        if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
            os.environ["ACCELERATE_USE_IPEX"] = "false"

        if self.minimize_memory:
            os.environ["PT_HPU_FP8_MINIMIZE_MEMORY"] = "true"

        self._n_gpu = 1
        if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
            accelerator_state_kwargs["cpu"] = True
            self._n_gpu = 0
        elif self.use_habana:
            # Some methods needs to be tweaked to optimally run on Gaudi
            # Calling this method here to be sure it is done before model instantiation
            # Otherwise this will fail when some __init__ methods are overridden (cf. GPT2Attention)
            from .modeling_utils import adapt_transformers_to_gaudi

            adapt_transformers_to_gaudi()

            if self.use_lazy_mode:
                logger.info("Enabled lazy mode.")
            elif not self.torch_compile:
                # Neither lazy nor compile mode: eager mode must be requested explicitly via the env var.
                if os.getenv("PT_HPU_LAZY_MODE", "1") != "0":
                    raise ValueError(
                        "Lazy mode or compile mode not enabled => eager mode should be enabled using PT_HPU_LAZY_MODE=0"
                    )

            accelerator_state_kwargs["cpu"] = False
            accelerator_state_kwargs["use_deepspeed"] = self.deepspeed
            accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
        else:
            raise ValueError(
                "No device has been set. Use either --use_habana to run on HPU or --use_cpu to run on CPU."
            )

        # Now we pop everything
        if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
            "use_configured_state", False
        ):
            # We need to patch this env var when enabling to detect deepspeed
            use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
            if use_deepspeed:
                os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
            try:
                self.distributed_state = PartialState(**accelerator_state_kwargs)
            finally:
                # Always undo the temporary patch, even when `PartialState` raises, so the variable
                # does not leak into whatever runs next in this process.
                if use_deepspeed:
                    os.environ.pop("ACCELERATE_USE_DEEPSPEED", None)

        # Sequence parallelism
        if self.parallel_mode == ParallelMode.DISTRIBUTED:
            if parallel_state.is_unitialized():
                parallel_state.initialize_model_parallel(
                    sequence_parallel_size=self.context_parallel_size, use_fp8=False
                )
            else:
                # Already initialized elsewhere: only check that it agrees with the requested size.
                if parallel_state.get_sequence_parallel_world_size() != self.context_parallel_size:
                    raise ValueError(
                        "The initialized sequence parallel world size does not match the context parallel size."
                    )
                if parallel_state.amax_reduction_is_initialized():
                    logger.info("FP8 amax reduction group is already initialized.")

        device = self.distributed_state.device
        self.local_rank = self.distributed_state.local_process_index

        if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
            logger.warning(
                "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
            )

        if self.distributed_state.distributed_type == DistributedType.NO:
            self._n_gpu = 0

        return device