def __call__()

in src/chug/wds/decode.py [0:0]


    def __call__(self, key, data):
        """Decode one page of a multi-page document sample (PDF or TIFF).

        Samples whose extension is not ``pdf``, ``tiff`` or ``tif`` are ignored
        and ``None`` is returned.

        The output format comes from ``self.imagespec``, which is looked up in
        ``wds.autodecode.imagespecs`` to obtain ``(atype, etype, mode)``. Which
        page is decoded is controlled by ``self.page_sampling``: ``'random'``,
        ``'first'`` or ``'last'``.

        Args:
            key: sample key / file name. Only the text after the last ``.`` is
                used as the extension.
            data: raw document data, passed unchanged to the page decoder
                (presumably the encoded file bytes -- confirm against caller).

        Returns:
            ``None`` for unsupported extensions. Otherwise:

            - for ``atype == "pil"``: the page result exactly as returned by the
              page decoder;
            - for ``atype == "numpy"``: a numpy array;
            - for ``atype == "torch"``: a torch tensor, with channels moved to
              the first axis for 3-D arrays.

        Raises:
            ValueError: if ``self.page_sampling`` is not a supported value.
        """
        extension = re.sub(r".*[.]", "", key)
        if extension not in {'pdf', 'tiff', 'tif'}:
            return None

        atype, etype, mode = wds.autodecode.imagespecs[self.imagespec]

        # Translate the sampling strategy into page-decoder arguments.
        select_random = False
        if self.page_sampling == 'random':
            page_indices = None
            select_random = True
        elif self.page_sampling == 'first':
            page_indices = [0]  # first page
        elif self.page_sampling == 'last':
            page_indices = [-1]  # last page
        else:
            # Raise explicitly rather than `assert False`: asserts are stripped
            # under `python -O`, which would leave `page_indices` unbound.
            raise ValueError(
                f"Unsupported page_sampling {self.page_sampling!r}; "
                "expected 'random', 'first' or 'last'."
            )

        # PDF documents and multi-page images (e.g. tiff) use different decoders
        # that share the same calling convention.
        decode_pages = decode_pdf_pages if extension == 'pdf' else decode_image_pages
        result, _num_pages = decode_pages(
            data,
            image_mode=mode.upper(),
            page_indices=page_indices,
            select_random=select_random,
        )

        if atype == "pil":
            return result

        result = np.asarray(result)

        if etype == "float":
            result = result.astype(np.float32) / 255.0
        # Value of a fully opaque alpha channel. It must match the value range
        # chosen above: 0..1 for float output, 0..255 otherwise.
        opaque = 1.0 if etype == "float" else 255

        assert result.ndim in [2, 3], result.shape
        assert mode in ["l", "rgb", "rgba"], mode

        # Coerce the channel layout to the requested mode; axis 2 is channels.
        if mode == "l":
            if result.ndim == 3:
                result = np.mean(result[:, :, :3], axis=2)
        elif mode == "rgb":
            if result.ndim == 2:
                result = np.repeat(result[:, :, np.newaxis], 3, axis=2)
            elif result.shape[2] == 4:
                result = result[:, :, :3]
        elif mode == "rgba":
            if result.ndim == 2:
                result = np.repeat(result[:, :, np.newaxis], 4, axis=2)
                result[:, :, 3] = opaque
            elif result.shape[2] == 3:
                # The alpha plane must be 3-D (H, W, 1) to concatenate along the
                # channel axis. It shares the image dtype, so uint8 stays uint8.
                alpha = np.full(
                    result.shape[:2] + (1,), opaque, dtype=result.dtype
                )
                result = np.concatenate([result, alpha], axis=2)

        assert atype in ["numpy", "torch"], atype

        if atype == "numpy":
            return result
        elif atype == "torch":
            import torch

            if result.ndim == 3:
                # Move the channel axis first: (H, W, C) -> (C, H, W).
                return torch.from_numpy(result.transpose(2, 0, 1))
            else:
                return torch.from_numpy(result)

        return None