def _tensorize()

in src/datasets/formatting/jax_formatter.py [0:0]


    def _tensorize(self, value):
        """Convert one leaf value into a JAX array placed on ``self.device``.

        Some values are handed back early instead of becoming arrays:

        * ``str``, ``bytes`` and ``None`` are returned unchanged;
        * NumPy character scalars/arrays are converted with ``.tolist()``;
        * torchvision ``VideoReader`` and torchcodec ``VideoDecoder`` /
          ``AudioDecoder`` objects are returned unchanged (only checked when the
          corresponding library is available and already present in
          ``sys.modules``).

        NumPy integer values get a default ``dtype`` of ``jnp.int64`` when the
        JAX ``jax_enable_x64`` flag is set and ``jnp.int32`` otherwise. NumPy
        floating values default to ``jnp.float32``. A ``dtype`` given in
        ``self.jnp_array_kwargs`` overrides these defaults. PIL images are turned
        into NumPy arrays with ``np.asarray`` *after* the default dtype has been
        chosen, so no default dtype is imposed on them.

        Args:
            value: The value to convert.

        Returns:
            ``value`` itself, a Python object produced by ``.tolist()``, or the
            result of ``jnp.array`` evaluated under ``jax.default_device`` for
            the device that ``self.device`` maps to.
        """
        global DEVICE_MAPPING

        import jax
        import jax.numpy as jnp

        # Text, raw bytes and missing values are passed through untouched.
        if isinstance(value, (str, bytes, type(None))):
            return value
        # NumPy string/bytes data is returned as plain Python objects.
        if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
            return value.tolist()

        # The default dtype is derived from the value as received, i.e. before the
        # PIL conversion further down.
        dtype_kwargs = {}
        if isinstance(value, (np.number, np.ndarray)):
            if np.issubdtype(value.dtype, np.integer):
                # The default int precision follows the JAX x64 flag, see
                # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
                int_dtype = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
                dtype_kwargs = {"dtype": int_dtype}
            elif np.issubdtype(value.dtype, np.floating):
                dtype_kwargs = {"dtype": jnp.float32}

        if config.PIL_AVAILABLE and "PIL" in sys.modules:
            import PIL.Image

            if isinstance(value, PIL.Image.Image):
                value = np.asarray(value)

        if config.TORCHVISION_AVAILABLE and "torchvision" in sys.modules:
            from torchvision.io import VideoReader

            if isinstance(value, VideoReader):
                return value  # TODO(QL): set output to jax arrays ?

        if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules:
            from torchcodec.decoders import AudioDecoder, VideoDecoder

            if isinstance(value, (VideoDecoder, AudioDecoder)):
                return value  # TODO(QL): set output to jax arrays ?

        if DEVICE_MAPPING is None:
            # `jaxlib.xla_extension.Device` objects can be serialized with neither
            # `pickle` nor `dill`, so the device mapping is kept in a module-level
            # global (built on first use) rather than on the instance.
            DEVICE_MAPPING = self._map_devices_to_str()
        target_device = DEVICE_MAPPING[self.device]

        with jax.default_device(target_device):
            # `jnp.array` copies the data of a NumPy array,
            # see https://github.com/google/jax/issues/4486.
            # Caller-supplied kwargs win over the default dtype chosen above.
            return jnp.array(value, **{**dtype_kwargs, **self.jnp_array_kwargs})