def _xla_gather()

in optimum/neuron/accelerate/utils/operations.py


import torch

# Assumed module-level imports from operations.py (the source file uses relative
# imports for the optimum-neuron helper):
from accelerate.utils.operations import recursively_apply

from optimum.neuron.utils import is_neuronx_distributed_available


def _xla_gather(tensor, out_of_graph: bool = False):
    import torch_xla.core.xla_model as xm

    groups = None
    if is_neuronx_distributed_available():
        from neuronx_distributed.parallel_layers.parallel_state import (
            get_data_parallel_group,
            model_parallel_is_initialized,
        )

        if model_parallel_is_initialized():
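            # With model parallelism initialized, restrict the gather to the
            # data-parallel replica groups.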
            groups = get_data_parallel_group(as_list=True)

    def _xla_gather_one(tensor):
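        # Gathering needs at least 1-D tensors, so promote 0-d scalars to shape (1,).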
        if tensor.ndim == 0:
            tensor = tensor.clone()[None]
        # Can only gather contiguous tensors
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        if out_of_graph:
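            # Out-of-graph path: xm.mesh_reduce with an identity reduce_fn performs a
            # rendezvous across replicas and returns one tensor per replica as a list.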
            gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x)
            if groups is not None:
                new_gathered_tensors = []
                # Since `groups` is a list of replica groups, visiting the first group is
                # enough: the values should be the same across the other axes.
                replicas_to_consider = set(groups[0])
                for idx, gathered_tensor in enumerate(gathered_tensors):
                    if idx not in replicas_to_consider:
                        continue
                    new_gathered_tensors.append(gathered_tensor)
                gathered_tensors = new_gathered_tensors
            gathered = torch.cat(gathered_tensors)
        else:
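            # In-graph path: all_gather concatenates the per-replica tensors along dim 0
            # inside the XLA graph, restricted to `groups` when set.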
            gathered = xm.all_gather(tensor, groups=groups, pin_layout=False)
        return gathered

    res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True)
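    # Flush the pending XLA graph so the gather is actually executed.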
    xm.mark_step()
    return res
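
A minimal usage sketch, assuming a script already running on XLA/Neuron devices with replicas launched (e.g. via torchrun); the tensor name and value below are illustrative, not taken from the source:

    import torch
    import torch_xla.core.xla_model as xm

    # Hypothetical per-replica metric to be collected across data-parallel workers.
    device = xm.xla_device()
    per_replica_loss = torch.tensor(0.42, device=device)  # 0-d is fine, it gets promoted to 1-D

    # In-graph gather: stays lazy until the final xm.mark_step() inside _xla_gather.
    all_losses = _xla_gather(per_replica_loss)

    # Out-of-graph gather: goes through xm.mesh_reduce, i.e. a rendezvous outside the graph.
    all_losses_out_of_graph = _xla_gather(per_replica_loss, out_of_graph=True)

In both cases the result's first dimension is the number of data-parallel replicas, since each scalar is promoted to shape (1,) before being concatenated along dim 0.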