in geospatial/data/StreamingDatasets.py [0:0]
def stream_chips(self):
for img_fn, label_fn, group in self.stream_tile_fns():
num_skipped_chips = 0
# Open file pointers
img_fp = rasterio.open(img_fn, "r")
label_fp = rasterio.open(label_fn, "r") if self.use_labels else None
height, width = img_fp.shape
if self.use_labels: # garuntee that our label mask has the same dimensions as our imagery
t_height, t_width = label_fp.shape
assert height == t_height and width == t_width
try:
# If we aren't in windowed sampling mode then we should read the entire tile up front
if not self.windowed_sampling:
img_data = np.rollaxis(img_fp.read(), 0, 3)
if self.use_labels:
label_data = label_fp.read().squeeze() # assume the label geotiff has a single channel
for i in range(self.num_chips_per_tile):
# Select the top left pixel of our chip randomly
x = np.random.randint(0, width-self.chip_size)
y = np.random.randint(0, height-self.chip_size)
# Read imagery / labels
img = None
labels = None
if self.windowed_sampling:
img = np.rollaxis(img_fp.read(window=Window(x, y, self.chip_size, self.chip_size)), 0, 3)
if self.use_labels:
labels = label_fp.read(window=Window(x, y, self.chip_size, self.chip_size)).squeeze()
else:
img = img_data[y:y+self.chip_size, x:x+self.chip_size, :]
if self.use_labels:
labels = label_data[y:y+self.chip_size, x:x+self.chip_size]
# Check for no data
if self.nodata_check is not None:
if self.use_labels:
skip_chip = self.nodata_check(img, labels)
else:
skip_chip = self.nodata_check(img)
if skip_chip: # The current chip has been identified as invalid by the `nodata_check(...)` method
num_skipped_chips += 1
continue
# Transform the imagery
if self.image_transform is not None:
if self.groups is None:
img = self.image_transform(img)
else:
img = self.image_transform(img, group)
else:
img = torch.from_numpy(img).squeeze()
# Transform the labels
if self.use_labels:
if self.label_transform is not None:
if self.groups is None:
labels = self.label_transform(labels)
else:
labels = self.label_transform(labels, group)
else:
labels = torch.from_numpy(labels).squeeze()
# Note, that img should be a torch "Double" type (i.e. a np.float32) and labels should be a torch "Long" type (i.e. np.int64)
if self.use_labels:
yield img, labels
else:
yield img
except RasterioIOError as e: # NOTE(caleb): I put this here to catch weird errors that I was seeing occasionally when trying to read from COGS - I don't remember the details though
print("WARNING: Reading %s failed, skipping..." % (img_fn))
# Close file pointers
img_fp.close()
if self.use_labels:
label_fp.close()
if num_skipped_chips>0 and self.verbose:
print("We skipped %d chips on %s" % (img_fn))