in datasets/iamdb.py [0:0]
def __init__(self, data_path, preprocessor, split, augment=False):
    """Load one split of the IAM handwriting dataset.

    Args:
        data_path: Directory containing the form images (`<key>.png`) and
            the per-split key lists (`<name>.txt`, one key per line).
        preprocessor: Provides `wordsep` and `use_words` for metadata
            loading, and `num_features` for image sizing.
        split: Split name; must be a key of the module-level `SPLITS` dict.
        augment: If True, prepend random crop / rotation / color-jitter
            transforms for data augmentation.

    Raises:
        ValueError: If `split` is not a key of `SPLITS`.
    """
    forms = load_metadata(
        data_path, preprocessor.wordsep, use_words=preprocessor.use_words
    )
    # Get split keys:
    splits = SPLITS.get(split, None)
    if splits is None:
        split_names = ", ".join(f"'{k}'" for k in SPLITS.keys())
        raise ValueError(f"Invalid split {split}, must be in [{split_names}].")
    # Use a set: membership is tested once per example in the loop below,
    # so O(1) lookup replaces an O(n) list scan per example.
    split_keys = set()
    for s in splits:
        with open(os.path.join(data_path, f"{s}.txt"), "r") as fid:
            split_keys.update(line.strip() for line in fid)
    self.preprocessor = preprocessor
    # Setup image transforms; augmentation transforms (when enabled) run
    # before the tensor conversion and normalization:
    self.transforms = []
    if augment:
        self.transforms.extend(
            [
                RandomResizeCrop(),
                transforms.RandomRotation(2, fill=(255,)),
                transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
            ]
        )
    self.transforms.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.912], std=[0.168]),
        ]
    )
    self.transforms = transforms.Compose(self.transforms)
    # Collect (image-args, text) pairs for examples in this split;
    # `images` and `text` stay index-aligned for the final zip.
    images = []
    text = []
    for key, examples in forms.items():
        for example in examples:
            if example["key"] not in split_keys:
                continue
            img_file = os.path.join(data_path, f"{key}.png")
            images.append((img_file, example["box"], preprocessor.num_features))
            text.append(example["text"])
    # Load images in parallel; pool size of 16 is a fixed heuristic.
    with mp.Pool(processes=16) as pool:
        images = pool.map(load_image, images)
    self.dataset = list(zip(images, text))