def _register_torchacc_ops()

in easycv/toolkit/torchacc/convert_ops.py
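
The function below relies on several names defined at module level that are not
part of this excerpt. A minimal sketch of that context, with assumed (not
confirmed) import paths, might look like:

# Assumed module-level context for _register_torchacc_ops; the torch_xla import
# paths below are assumptions, not taken from this excerpt.
import torch
import torch_xla.core.xla_model as xm
from torch.distributed import ReduceOp
# `_ops_manager`, `OpSpec`, `TORCHACC_MODULE_NAME`, `DEFAULT_TAG`, `xla_optim`
# and `torchacc_amp` are defined elsewhere in the toolkit and are not shown
# here.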


def _register_torchacc_ops():
    global _ops_manager

    reduce_op_map = {
        ReduceOp.SUM: xm.REDUCE_SUM,
        ReduceOp.PRODUCT: xm.REDUCE_MUL,
        ReduceOp.MIN: xm.REDUCE_MIN,
        ReduceOp.MAX: xm.REDUCE_MAX,
        ReduceOp.BAND: xm.REDUCE_AND,
        ReduceOp.BOR: xm.REDUCE_OR,
    }

    module_name = TORCHACC_MODULE_NAME
    _ops_manager.register(module_name)
    collector = _ops_manager.get_collector(module_name)

    origin_to = torch.Tensor.to
    origin_tensor = torch.tensor
    origin_zeros = torch.zeros
    torchacc_device = xm.xla_device()

    from typing import Any, Optional, Union
    from torch.types import _int, _bool, _dtype, _device

    def torchacc_is_initialized():
        # TODO: add initialize attribute
        # keep consistent with torch dist behavior
        return xm.xrt_world_size() > 1

    def torchacc_to(self,
                    device: Optional[Union[_device, str]] = None,
                    dtype: Optional[_dtype] = None,
                    non_blocking: _bool = False,
                    copy: _bool = False) -> torch.Tensor:
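        # Redirect explicit CUDA placement to the XLA device; other devices
        # and dtypes pass through to the original Tensor.to.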
        if device is not None and str(device).startswith('cuda'):
            device = torchacc_device
        return origin_to(
            self,
            device=device,
            dtype=dtype,
            non_blocking=non_blocking,
            copy=copy)

    # must be patched via setattr after `torchacc_to` (it delegates to `self.to`)
    def torchacc_cuda(self,
                      device: Optional[Union[_device, _int, str]] = None,
                      non_blocking: _bool = False) -> torch.Tensor:
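        # Always place the tensor on the XLA device; any requested CUDA index
        # is ignored.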
        assert torch.cuda.is_available()
        device = torchacc_device
        return self.to(device=device, non_blocking=non_blocking)

    def torchacc_tensor(data: Any,
                        dtype: Optional[_dtype] = None,
                        device: Union[_device, str, None] = None,
                        requires_grad: _bool = False) -> torch.Tensor:
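        # Same redirection as above: a 'cuda' device becomes the XLA device
        # before delegating to the original torch.tensor.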
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_tensor(
            data=data, dtype=dtype, device=device, requires_grad=requires_grad)

    def torchacc_zeros(*args, device=None, **kwargs):
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_zeros(*args, device=device, **kwargs)

    def torchacc_barrier(tag=None, **kwargs):
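        # Emulate torch.distributed.barrier with an XLA rendezvous, using a
        # default tag when none is given.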
        if tag is None:
            tag = DEFAULT_TAG
        return xm.rendezvous(tag, **kwargs)

    def torchacc_all_reduce(tensor, op=ReduceOp.SUM, **kwargs):
        # Map torch.distributed.all_reduce onto xm.all_reduce via the
        # reduce_op_map defined above; single tensors are wrapped in a list.
        if not isinstance(tensor, list):
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)
        else:
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)

    def torchacc_reduce(tensor, dst, op=ReduceOp.SUM, **kwargs):
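        # Emulate torch.distributed.reduce: every rank joins the all_reduce,
        # but only rank `dst` keeps the reduced result in place; other ranks
        # reduce a clone and discard it.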
        # if tensor.device.type != 'gpu':
        #     return
        if xm.get_ordinal() != dst:
            tensor = tensor.clone()
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)
        else:
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)

    def torchacc_broadcast(**kwargs):
        raise ValueError('broadcast is not supported for torchacc yet!')

    def torchacc_all_gather(tensor_list, tensor, **kwargs):
        # Emulate torch.distributed.all_gather: gather along dim 0 with
        # xm.all_gather, then split the result back into per-rank chunks.
        if len(tensor.size()) == 0:
            raise ValueError(
                '``all_gather`` does not support scalar (0-dim) tensors '
                'for torchacc!')

        res = xm.all_gather(value=tensor, dim=0, **kwargs)
        splits = torch.tensor_split(res, len(tensor_list))

        for i in range(len(tensor_list)):
            assert splits[i].size() == tensor.size(), \
                'mismatch size: {}, {}'.format(splits[i].size(), tensor.size())
            tensor_list[i] = splits[i]
        del splits

    collector.add_op(
        'TO', OpSpec(module=None, name=None,
                     value=torchacc_to))  # without `to` function, module=None
    collector.add_op('CUDA', OpSpec(
        module=None, name=None,
        value=torchacc_cuda))  # without `cuda` function, module=None
    collector.add_op('TENSOR',
                     OpSpec(module=None, name=None, value=torchacc_tensor))
    collector.add_op('ZEROS',
                     OpSpec(module=None, name=None, value=torchacc_zeros))
    collector.add_op(
        'GET_RANK',
        OpSpec(module=xm, name='get_ordinal', value=xm.get_ordinal))
    collector.add_op(
        'GET_WORLD_SIZE',
        OpSpec(module=xm, name='xrt_world_size', value=xm.xrt_world_size))
    collector.add_op(
        'BARRIER',
        OpSpec(module=xm, name='rendezvous', value=torchacc_barrier))
    collector.add_op(
        'ALL_REDUCE',
        OpSpec(module=xm, name='all_reduce', value=torchacc_all_reduce))
    collector.add_op('REDUCE',
                     OpSpec(module=None, name=None, value=torchacc_reduce))
    collector.add_op('BROADCAST',
                     OpSpec(module=None, name=None, value=torchacc_broadcast)
                     )  # without `broadcast` function, module=None
    collector.add_op(
        'ALL_GATHER',
        OpSpec(module=xm, name='all_gather', value=torchacc_all_gather))
    collector.add_op(
        'IS_INITIALIZED',
        OpSpec(module=None, name=None, value=torchacc_is_initialized))
    collector.add_op(
        'ADAM', OpSpec(module=xla_optim, name='Adam', value=xla_optim.Adam))
    collector.add_op(
        'ADAMW', OpSpec(module=xla_optim, name='AdamW', value=xla_optim.AdamW))
    collector.add_op('SGD',
                     OpSpec(module=xla_optim, name='SGD', value=xla_optim.SGD))
    collector.add_op(
        'GRADSCALER',
        OpSpec(
            module=torchacc_amp,
            name='GradScaler',
            value=torchacc_amp.GradScaler))

    return collector
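
A minimal usage sketch, assuming torch_xla is available and that the
surrounding ops manager applies the collected OpSpecs (the patching step is
not shown in this excerpt):

collector = _register_torchacc_ops()
# Once the registered values have been swapped in, CUDA-targeted calls are
# expected to land on the XLA device instead, e.g.:
#   x = torch.zeros(4, device='cuda')  # placed on xm.xla_device()
#   y = x.cuda()                       # also the XLA device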