in easycv/toolkit/torchacc/convert_ops.py [0:0]
def _register_torchacc_ops():
    global _ops_manager

    # Map torch.distributed reduce ops to their torch_xla equivalents.
    reduce_op_map = {
        ReduceOp.SUM: xm.REDUCE_SUM,
        ReduceOp.PRODUCT: xm.REDUCE_MUL,
        ReduceOp.MIN: xm.REDUCE_MIN,
        ReduceOp.MAX: xm.REDUCE_MAX,
        ReduceOp.BAND: xm.REDUCE_AND,
        ReduceOp.BOR: xm.REDUCE_OR,
    }

    module_name = TORCHACC_MODULE_NAME
    _ops_manager.register(module_name)
    collector = _ops_manager.get_collector(module_name)

    # Keep handles to the original torch callables before they get patched.
    origin_to = torch.Tensor.to
    origin_tensor = torch.tensor
    origin_zeros = torch.zeros
    torchacc_device = xm.xla_device()

    from typing import Any, Optional, Union

    from torch.types import _int, _bool, _dtype, _device
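
    # The wrappers below mirror the torch / torch.distributed APIs they stand
    # in for: tensor placement and creation ops reroute explicit 'cuda'
    # targets to the XLA device, while the collectives delegate to the
    # corresponding torch_xla primitives.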
    def torchacc_is_initialized():
        # TODO: add an ``initialized`` attribute; for now mirror the
        # torch.distributed behavior by treating world size > 1 as initialized.
        return xm.xrt_world_size() > 1

    def torchacc_to(self,
                    device: Optional[Union[_device, str]] = None,
                    dtype: Optional[_dtype] = None,
                    non_blocking: _bool = False,
                    copy: _bool = False) -> torch.Tensor:
        # Redirect explicit CUDA placements to the XLA device.
        if device is not None and str(device).startswith('cuda'):
            device = torchacc_device
        return origin_to(
            self,
            device=device,
            dtype=dtype,
            non_blocking=non_blocking,
            copy=copy)

    # Must be patched (setattr) after `torchacc_to`: it delegates to `self.to`,
    # which should already point at the wrapper above.
    def torchacc_cuda(self,
                      device: Optional[Union[_device, _int, str]] = None,
                      non_blocking: _bool = False) -> torch.Tensor:
        assert torch.cuda.is_available()
        # `.cuda()` always lands on the XLA device, regardless of the requested index.
        device = torchacc_device
        return self.to(device=device, non_blocking=non_blocking)

    def torchacc_tensor(data: Any,
                        dtype: Optional[_dtype] = None,
                        device: Union[_device, str, None] = None,
                        requires_grad: _bool = False) -> torch.Tensor:
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_tensor(
            data=data, dtype=dtype, device=device, requires_grad=requires_grad)

    def torchacc_zeros(*args, device=None, **kwargs):
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_zeros(*args, device=device, **kwargs)

    def torchacc_barrier(tag=None, **kwargs):
        if tag is None:
            tag = DEFAULT_TAG
        return xm.rendezvous(tag, **kwargs)

    def torchacc_all_reduce(tensor, op=ReduceOp.SUM, **kwargs):
        # xm.all_reduce reduces a list of tensors in place; wrap a single
        # tensor in a list so it is updated in place like torch.distributed.
        if not isinstance(tensor, list):
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)
        else:
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)

    def torchacc_reduce(tensor, dst, op=ReduceOp.SUM, **kwargs):
        # Emulate torch.distributed.reduce: every rank joins the all-reduce,
        # but only the dst rank keeps the result in place; other ranks reduce
        # a clone so their input tensor is left untouched.
        if xm.get_ordinal() != dst:
            tensor = tensor.clone()
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)
        else:
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)

    def torchacc_broadcast(**kwargs):
        raise ValueError('``broadcast`` is not supported for torchacc yet!')

    def torchacc_all_gather(tensor_list, tensor, **kwargs):
        if len(tensor.size()) == 0:
            raise ValueError(
                '``all_gather`` does not support scalar tensors for torchacc!')
        # xm.all_gather concatenates along dim 0; split the result back into
        # per-rank chunks to fill ``tensor_list`` like torch.distributed does.
        res = xm.all_gather(value=tensor, dim=0, **kwargs)
        splits = torch.tensor_split(res, len(tensor_list))
        for i in range(len(tensor_list)):
            assert splits[i].size() == tensor.size(), \
                'mismatch size: {}, {}'.format(splits[i].size(), tensor.size())
            tensor_list[i] = splits[i]
        del splits
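
    # Register every wrapper with the collector under a symbolic op key so the
    # ops manager can later swap the matching torch / torch.distributed /
    # optimizer call sites for their torchacc equivalents.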
    collector.add_op(
        'TO',
        OpSpec(module=None, name=None,
               value=torchacc_to))  # xm has no `to` function, so module=None
    collector.add_op(
        'CUDA',
        OpSpec(module=None, name=None,
               value=torchacc_cuda))  # xm has no `cuda` function, so module=None
    collector.add_op('TENSOR',
                     OpSpec(module=None, name=None, value=torchacc_tensor))
    collector.add_op('ZEROS',
                     OpSpec(module=None, name=None, value=torchacc_zeros))
    collector.add_op(
        'GET_RANK',
        OpSpec(module=xm, name='get_ordinal', value=xm.get_ordinal))
    collector.add_op(
        'GET_WORLD_SIZE',
        OpSpec(module=xm, name='xrt_world_size', value=xm.xrt_world_size))
    collector.add_op(
        'BARRIER',
        OpSpec(module=xm, name='rendezvous', value=torchacc_barrier))
    collector.add_op(
        'ALL_REDUCE',
        OpSpec(module=xm, name='all_reduce', value=torchacc_all_reduce))
    collector.add_op('REDUCE',
                     OpSpec(module=None, name=None, value=torchacc_reduce))
    collector.add_op(
        'BROADCAST',
        OpSpec(module=None, name=None,
               value=torchacc_broadcast))  # xm has no `broadcast` function, so module=None
    collector.add_op(
        'ALL_GATHER',
        OpSpec(module=xm, name='all_gather', value=torchacc_all_gather))
    collector.add_op(
        'IS_INITIALIZED',
        OpSpec(module=None, name=None, value=torchacc_is_initialized))
    collector.add_op(
        'ADAM', OpSpec(module=xla_optim, name='Adam', value=xla_optim.Adam))
    collector.add_op(
        'ADAMW', OpSpec(module=xla_optim, name='AdamW', value=xla_optim.AdamW))
    collector.add_op('SGD',
                     OpSpec(module=xla_optim, name='SGD', value=xla_optim.SGD))
    collector.add_op(
        'GRADSCALER',
        OpSpec(
            module=torchacc_amp,
            name='GradScaler',
            value=torchacc_amp.GradScaler))

    return collector
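

# ---------------------------------------------------------------------------
# Illustrative sketch only, NOT part of convert_ops.py: one way the OpSpecs
# registered above might be applied. The dict-style ``collector.ops`` accessor
# and the ``_apply_torchacc_patches`` helper are assumptions made for this
# example; the real patching logic lives in EasyCV's ops manager.
# ---------------------------------------------------------------------------
def _apply_torchacc_patches(collector):
    ops = collector.ops  # assumed accessor: {'TO': OpSpec, 'CUDA': OpSpec, ...}
    # Tensor placement/creation wrappers replace the stock torch callables so
    # that existing ``cuda`` code paths transparently land on the XLA device.
    torch.Tensor.to = ops['TO'].value
    torch.Tensor.cuda = ops['CUDA'].value  # patched after `to`, see the note above
    torch.tensor = ops['TENSOR'].value
    torch.zeros = ops['ZEROS'].value
    # Collective wrappers stand in for their torch.distributed counterparts.
    torch.distributed.get_rank = ops['GET_RANK'].value
    torch.distributed.get_world_size = ops['GET_WORLD_SIZE'].value
    torch.distributed.barrier = ops['BARRIER'].value
    torch.distributed.all_reduce = ops['ALL_REDUCE'].value
    torch.distributed.all_gather = ops['ALL_GATHER'].value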