in optimum/neuron/generation/utils.py [0:0]
def _move_dict_args_to_device(kwargs: Dict[str, Any], device: str = "cpu") -> Dict[str, Any]:
"""
Takes keyword arguments which will be passed to a model's forward function
and moves its values to `device` if
they are of type `torch.Tensor`. If the key is a dictionary it does the same to the
respective values.
Args:
kwargs: (`Dict[str, Any]`):
The kwargs to be passed to the models forward function.
device: (`str`, defaults to `cpu`):
The target device to which tensors should be moved.
Returns:
`Dict[str, Any]`: The kwargs dict with its tensors moved to `device`.
"""
def needs_move(src_device, tgt_device):
return src_device != tgt_device
for k, v in kwargs.items():
# Handle nested dicts
if isinstance(v, dict):
for k_, v_ in v.items():
if isinstance(v_, torch.Tensor):
if needs_move(v_.device, device):
v[k_] = v_.to(device=device)
# Handle tensor types
elif isinstance(v, torch.Tensor):
if needs_move(v.device, device):
kwargs[k] = v.to(device=device)
# Handle past_key_value tuples
elif k == "past_key_values":
if v is not None:
new_past_key_values = ()
for layer_past in v:
new_layer_past = ()
for past_state in layer_past:
if needs_move(past_state.device, device):
new_layer_past += (past_state.to(device=device),)
else:
new_layer_past += (past_state,)
new_past_key_values += (new_layer_past,)
kwargs[k] = new_past_key_values
return kwargs