in dataset.py [0:0]
def __getitem__(self, index):
    """Return one (image, label) pair for the sample at *index*.

    In training mode the image is loaded with data augmentation (jitter,
    HSV shifts, random background) via ``load_data_detection``; in
    evaluation mode the image is read from disk, resized to the network
    input size, and its ground-truth labels are loaded from the matching
    ``labels/*.txt`` file.

    Args:
        index: Sample index in ``[0, len(self))``.

    Returns:
        tuple: ``(img, label)`` where ``img`` is the (optionally
        transformed) image and ``label`` is a flat ``torch`` tensor of
        ground-truth values.

    Raises:
        IndexError: If *index* is out of range.
    """
    # Valid indices are 0 .. len(self)-1, so reject index >= len(self)
    # (the original `index <= len(self)` admitted one past the end).
    # Raise IndexError — the sequence-protocol exception — rather than
    # using `assert`, which is stripped under `python -O`.
    if index >= len(self):
        raise IndexError('index range error')

    # Path of the image for this sample.
    imgpath = self.lines[index].rstrip()

    # Multi-scale training: once per batch, re-pick the square network
    # input size. The schedule is driven by the (approximate) epoch
    # number seen/(nbatches*batch_size); every 10 epochs the candidate
    # range widens by 2 cells and its minimum drops by 1 cell:
    #   epochs  0- 9: fixed 13 cells
    #   epochs 10-19: 13..20 cells  (randint(0, 7) + 13)
    #   epochs 20-29: 12..21 cells  (randint(0, 9) + 12)
    #   ...
    #   epochs 70+  :  7..26 cells  (randint(0, 19) + 7)
    if self.train and index % self.batch_size == 0:
        decade = self.seen // (10 * self.nbatches * self.batch_size)
        if decade == 0:
            width = 13 * self.cell_size
        else:
            stage = min(decade, 7)  # schedule saturates after epoch 70
            width = (random.randint(0, 2 * stage + 5) + (14 - stage)) * self.cell_size
        self.shape = (width, width)

    if self.train:
        # Fixed data-augmentation strengths (translation jitter and HSV
        # distortion ranges used by load_data_detection).
        jitter = 0.2
        hue = 0.1
        saturation = 1.5
        exposure = 1.5
        # Pick a random background image to composite behind the object.
        random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
        bgpath = self.bg_file_names[random_bg_index]
        # Load the augmented image and its corresponding labels.
        img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath, self.num_keypoints, self.max_num_gt)
        # Convert the numpy label array to a PyTorch tensor.
        label = torch.from_numpy(label)
    else:
        # Validation: load the image and resize it to the network input size.
        img = Image.open(imgpath).convert('RGB')
        if self.shape:
            img = img.resize(self.shape)
        # Derive the label-file path from the image path by convention.
        labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
        # Per-object label width: 2 coords per keypoint, +2 for
        # ground-truth width/height, +1 for the class label.
        num_labels = 2*self.num_keypoints+3
        # Pre-allocate a zero buffer holding up to max_num_gt objects.
        label = torch.zeros(self.max_num_gt*num_labels)
        if os.path.getsize(labpath):
            ow, oh = img.size
            tmp = torch.from_numpy(read_truths_args(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            if tsz > self.max_num_gt*num_labels:
                # More objects than the buffer holds: truncate.
                label = tmp[0:self.max_num_gt*num_labels]
            elif tsz > 0:
                label[0:tsz] = tmp

    # Transform the image data to PyTorch tensors.
    if self.transform is not None:
        img = self.transform(img)
    # If there is any PyTorch-specific transformation, transform the label data.
    if self.target_transform is not None:
        label = self.target_transform(label)

    # Advance the global sample counter; stepping by num_workers
    # approximates the total count when each DataLoader worker only
    # sees every num_workers-th sample — TODO confirm against loader setup.
    self.seen = self.seen + self.num_workers

    return (img, label)