def __init__()

in optimum/neuron/accelerate/state.py


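Constructor of the shared process state used by the Neuron accelerator. It resolves the torch device and debug flag from the environment, detects whether SageMaker data parallelism is requested, prepares the distributed backend, and, when the XLA backend is selected, sets the Neuron compiler flags and initializes the process group before querying the XLA runtime for the process topology.
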
    def __init__(self, cpu: bool = False, **kwargs):
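        # Bind this instance's __dict__ to the class-level _shared_state so that
        # every instance shares the same process state (borg/singleton pattern).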
        self.__dict__ = self._shared_state
        if not self.initialized:
            self._cpu = cpu
            self.backend = None
            env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
            self.device = torch.device(env_device) if env_device is not None else None
            self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
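            # Detect whether SageMaker data parallelism is requested, either explicitly
            # via kwargs or through the Accelerate environment variables.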
            use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
            if use_sagemaker_dp is None:
                use_sagemaker_dp = (
                    os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
                    and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
                )

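            # Resolve the distributed backend and distributed type; on Neuron, only the
            # XLA backend (or no backend at all) is accepted, as enforced below.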
            backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, kwargs.pop("backend", None))
            self.backend = backend
            self.distributed_type = distributed_type
            # No backend == no distributed training
            if self.backend is None:
                self.distributed_type = DistributedType.NO
                self.num_processes = 1
                self.process_index = 0
                self.local_process_index = 0
            elif self.backend != "xla":
                raise ValueError("Only the XLA backend is supported by `optimum-neuron`.")
            else:
                # It is important to set the environment variables before initializing the process
                # group, otherwise they will be ignored by the Neuron compiler.
                set_common_flags()
                if os.environ.get("ACCELERATE_USE_AMP", "false") == "true":
                    set_neuron_cc_flags_for_torch_amp()
                if not torch.distributed.is_initialized():
                    init_process_group()
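                # Topology comes from the torch_xla runtime: world size, global rank,
                # local rank, and the XLA device assigned to this process.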
                self.num_processes = xr.world_size()
                self.process_index = xr.global_ordinal()
                self.local_process_index = xm.get_local_ordinal()
                self.device = xm.xla_device()

        # Important: this should be the *only* code outside of the `self.initialized` check!
        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
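
For context, a minimal usage sketch of how this constructor is typically reached. The class name `NeuronPartialState` and the behavior on a non-Neuron host are assumptions, not shown in the excerpt above; the shared-state behavior follows directly from the `self.__dict__ = self._shared_state` assignment.

    # A minimal sketch, assuming the class is NeuronPartialState (name not shown above).
    from optimum.neuron.accelerate.state import NeuronPartialState

    state = NeuronPartialState()         # first instantiation runs the initialization above
    other = NeuronPartialState()         # later instantiations reuse the shared state
    assert state.device is other.device  # both views point at the same _shared_state
    print(state.distributed_type, state.num_processes, state.process_index)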