def _move_dict_args_to_device()

in optimum/neuron/generation/utils.py [0:0]


def _move_dict_args_to_device(kwargs: Dict[str, Any], device: str = "cpu") -> Dict[str, Any]:
    """
    Takes keyword arguments which will be passed to a model's forward function
    and moves every `torch.Tensor` value to `device`. If a value is a dictionary,
    the tensors it contains (one level deep) are moved as well. A non-`None`
    `past_key_values` entry is rebuilt as a tuple of tuples whose tensors live on `device`.

    `kwargs` and any nested dictionaries are updated in place; values that are not
    tensors are left untouched.

    Args:
        kwargs: (`Dict[str, Any]`):
            The kwargs to be passed to the models forward function.
        device: (`str`, defaults to `cpu`):
            The target device to which tensors should be moved.
    Returns:
        `Dict[str, Any]`: The same kwargs dict with its tensors moved to `device`.
    """

    def to_device(value):
        # `Tensor.to` returns the tensor itself when it already lives on the target
        # device, so no explicit device comparison is needed before calling it.
        if isinstance(value, torch.Tensor):
            return value.to(device=device)
        return value

    # Only existing keys are reassigned below, so the dicts never change size
    # while they are being iterated.
    for k, v in kwargs.items():
        # Handle nested dicts
        if isinstance(v, dict):
            for k_, v_ in v.items():
                if isinstance(v_, torch.Tensor):
                    v[k_] = v_.to(device=device)

        # Handle tensor types
        elif isinstance(v, torch.Tensor):
            kwargs[k] = v.to(device=device)

        # Handle past_key_value tuples; entries that are not tensors (e.g. `None`)
        # are carried over unchanged instead of failing on a missing `.device`.
        elif k == "past_key_values" and v is not None:
            kwargs[k] = tuple(
                tuple(to_device(past_state) for past_state in layer_past) for layer_past in v
            )

    return kwargs