in src/chug/wds/decode.py [0:0]
def __call__(self, key, data):
    """Decode one page of a PDF or multi-page image document.

    Args:
        key: file name extension; everything up to and including the last
            '.' is stripped before the extension is checked.
        data: data to be decoded

    Returns:
        None when the extension is not one of 'pdf', 'tiff' or 'tif', so
        the sample is left untouched by this handler. Otherwise the decoded
        page in the form selected by ``self.imagespec``:

        * atype "pil": the page decoder's result, returned unchanged.
        * atype "numpy": an array with the channel layout implied by the
          spec's mode ("l", "rgb" or "rgba"); divided by 255 and cast to
          float32 when the spec's etype is "float".
        * atype "torch": the same array wrapped in a tensor, with the
          channel axis moved first when the array is 3-D.

    Raises:
        AssertionError: if ``self.page_sampling`` is not 'random', 'first'
            or 'last', or if the image spec / decoded array violates the
            expectations asserted below.
        KeyError: if ``self.imagespec`` is not a key of
            ``wds.autodecode.imagespecs``.
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in {'pdf', 'tiff', 'tif'}:
        return None

    atype, etype, mode = wds.autodecode.imagespecs[self.imagespec]

    # Translate the sampling policy into the arguments the page decoders take.
    select_random = False
    if self.page_sampling == 'random':
        page_indices = None
        select_random = True
    elif self.page_sampling == 'first':
        page_indices = [0]  # first page
    elif self.page_sampling == 'last':
        page_indices = [-1]
    else:
        # Raised explicitly (rather than `assert False`) so the check
        # survives `python -O` instead of leaving page_indices unbound.
        raise AssertionError(
            f"unsupported page_sampling: {self.page_sampling!r}"
        )

    # PDFs and multi-page images (e.g. tiff) have separate decoders that
    # share the same call signature.
    decode_pages = decode_pdf_pages if extension == 'pdf' else decode_image_pages
    result, _num_pages = decode_pages(
        data,
        image_mode=mode.upper(),
        page_indices=page_indices,
        select_random=select_random,
    )

    if atype == "pil":
        return result

    result = np.asarray(result)
    if etype == "float":
        result = result.astype(np.float32) / 255.0
    # Fully opaque alpha must match the value range of the pixel data:
    # 1.0 once the image has been normalised to [0, 1], 255 otherwise.
    opaque = 1.0 if etype == "float" else 255

    assert result.ndim in [2, 3], result.shape
    assert mode in ["l", "rgb", "rgba"], mode
    if mode == "l":
        if result.ndim == 3:
            # Collapse colour channels (ignoring any alpha) to grayscale.
            result = np.mean(result[:, :, :3], axis=2)
    elif mode == "rgb":
        if result.ndim == 2:
            result = np.repeat(result[:, :, np.newaxis], 3, axis=2)
        elif result.shape[2] == 4:
            result = result[:, :, :3]
    elif mode == "rgba":
        if result.ndim == 2:
            # np.repeat returns a fresh array, so writing alpha is safe.
            result = np.repeat(result[:, :, np.newaxis], 4, axis=2)
            result[:, :, 3] = opaque
        elif result.shape[2] == 3:
            # The alpha plane needs an explicit trailing axis: concatenate
            # rejects inputs of differing dimensionality. Using the image's
            # own dtype avoids silently promoting uint8 data to float64.
            alpha = np.full(
                result.shape[:2] + (1,), opaque, dtype=result.dtype
            )
            result = np.concatenate([result, alpha], axis=2)

    assert atype in ["numpy", "torch"], atype
    if atype == "numpy":
        return result
    if atype == "torch":
        import torch

        if result.ndim == 3:
            # HWC -> CHW
            return torch.from_numpy(result.transpose(2, 0, 1))
        return torch.from_numpy(result)
    return None