in src/image_gen_aux/modeling_utils.py [0:0]
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """
    Return the device that a PyTorch module's tensors live on.

    The first tensor yielded by `parameter.parameters()` followed by `parameter.buffers()` decides the result. When
    the module registers neither parameters nor buffers, the lookup falls back to plain tensor attributes stored in
    module `__dict__`s, as discovered through `_named_members`.

    Args:
        parameter (`torch.nn.Module`):
            The PyTorch module to inspect (despite the argument name, this is a module, not a single parameter).

    Returns:
        `torch.device`: The device of the first parameter, buffer, or tensor attribute that was found.

    Raises:
        StopIteration: If the module holds no parameters, no buffers and no tensor attributes.
    """
    try:
        registered_tensors = itertools.chain(parameter.parameters(), parameter.buffers())
        first_registered = next(registered_tensors)
        return first_registered.device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5: nothing is reported through
        # `parameters()`/`buffers()`, so look for tensors kept as plain attributes instead.
        def collect_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            found = []
            for attr_name, attr_value in module.__dict__.items():
                if torch.is_tensor(attr_value):
                    found.append((attr_name, attr_value))
            return found

        tensor_members = parameter._named_members(get_members_fn=collect_tensor_attributes)
        # An exhausted generator raises StopIteration here, which propagates to the caller.
        _, first_tensor = next(tensor_members)
        return first_tensor.device