def map_pixels()

in dall_e/utils.py [0:0]


def map_pixels(x: torch.Tensor) -> torch.Tensor:
	"""Squeeze pixel values away from 0 and 1 with an affine map.

	Computes ``(1 - 2 * logit_laplace_eps) * x + logit_laplace_eps``, which
	sends 0 to ``logit_laplace_eps`` and 1 to ``1 - logit_laplace_eps``.
	NOTE(review): presumably ``x`` holds pixels already scaled to [0, 1] and
	``logit_laplace_eps`` is a small positive module-level constant — confirm.

	Args:
		x: a 4-dimensional tensor whose dtype is exactly ``torch.float``
			(i.e. ``torch.float32``).

	Returns:
		``x`` with the affine map above applied.

	Raises:
		ValueError: if ``x`` is not 4-dimensional, or if its dtype is not
			``torch.float``.
	"""
	# Rank is validated before dtype, so an input failing both reports the rank.
	if x.ndim != 4:
		raise ValueError('expected input to be 4d')
	# torch.float32 is the same dtype object as torch.float.
	if x.dtype != torch.float32:
		raise ValueError('expected input to have type float')

	scale = 1 - 2 * logit_laplace_eps
	return scale * x + logit_laplace_eps