def stream_chips()

in geospatial/data/StreamingDatasets.py [0:0]


    def stream_chips(self):
        """Yield randomly sampled chips from every tile returned by ``self.stream_tile_fns()``.

        For each ``(img_fn, label_fn, group)`` triple, the imagery (and, when
        ``self.use_labels`` is set, the label raster) is opened with rasterio and
        ``self.num_chips_per_tile`` square chips of side ``self.chip_size`` are
        sampled at uniformly random top-left offsets.

        When ``self.windowed_sampling`` is false, the whole tile is read once and
        chips are sliced from it. Otherwise each chip is read with a rasterio
        ``Window``.

        Imagery is moved to channels-last order ``(rows, cols, bands)`` with
        ``np.rollaxis(..., 0, 3)``. Labels are read and ``squeeze()``-d on the
        assumption that the label raster has a single band.

        Chips for which ``self.nodata_check`` returns a truthy value are dropped.
        They still count toward ``num_chips_per_tile``. Each remaining chip is
        passed through ``self.image_transform`` / ``self.label_transform``, with
        ``group`` as an extra argument when ``self.groups`` is not None. If no
        transform is set, the chip is converted with
        ``torch.from_numpy(...).squeeze()``.

        Yields:
            ``(img, labels)`` when ``self.use_labels`` is true, otherwise ``img``.

        Raises:
            AssertionError: if the label raster's height/width differ from the
                imagery's.
            ValueError: (from ``np.random.randint``) if a tile is smaller than
                ``self.chip_size`` in either dimension.

        A ``RasterioIOError`` raised while sampling a tile is reported with a
        warning, and the rest of that tile is skipped. File handles are always
        closed, including when the consumer stops iterating early.
        """
        chip_size = self.chip_size

        def _finalize(data, transform, group):
            # Run the configured transform (passing the tile's group when
            # `self.groups` is set) or fall back to a plain tensor conversion.
            if transform is None:
                return torch.from_numpy(data).squeeze()
            if self.groups is None:
                return transform(data)
            return transform(data, group)

        for img_fn, label_fn, group in self.stream_tile_fns():
            num_skipped_chips = 0

            img_fp = rasterio.open(img_fn, "r")
            label_fp = None
            # Use try/finally (instead of closing at the end of the loop body) so
            # the handles are released even if a non-IO error is raised, or if the
            # consumer abandons the generator (GeneratorExit is raised at `yield`).
            try:
                if self.use_labels:
                    label_fp = rasterio.open(label_fn, "r")

                height, width = img_fp.shape
                if self.use_labels:
                    # Guarantee that the label mask has the same dimensions as the imagery
                    t_height, t_width = label_fp.shape
                    assert height == t_height and width == t_width, (
                        "Label raster %s has shape %s but imagery %s has shape %s"
                        % (label_fn, label_fp.shape, img_fn, img_fp.shape)
                    )

                # np.random.randint's upper bound is exclusive. The +1 makes the
                # right/bottom-most valid top-left offset reachable, and lets a tile
                # that is exactly `chip_size` wide/high be sampled.
                max_x = width - chip_size + 1
                max_y = height - chip_size + 1

                try:
                    # Outside windowed sampling mode, read the entire tile up front
                    if not self.windowed_sampling:
                        img_data = np.rollaxis(img_fp.read(), 0, 3)
                        if self.use_labels:
                            # assume the label geotiff has a single channel
                            label_data = label_fp.read().squeeze()

                    for _ in range(self.num_chips_per_tile):
                        # Select the top-left pixel of the chip at random
                        x = np.random.randint(0, max_x)
                        y = np.random.randint(0, max_y)

                        # Read imagery / labels for this chip
                        labels = None
                        if self.windowed_sampling:
                            # Window(col_off, row_off, width, height)
                            window = Window(x, y, chip_size, chip_size)
                            img = np.rollaxis(img_fp.read(window=window), 0, 3)
                            if self.use_labels:
                                labels = label_fp.read(window=window).squeeze()
                        else:
                            img = img_data[y:y + chip_size, x:x + chip_size, :]
                            if self.use_labels:
                                labels = label_data[y:y + chip_size, x:x + chip_size]

                        # Drop chips that `nodata_check(...)` identifies as invalid
                        if self.nodata_check is not None:
                            if self.use_labels:
                                skip_chip = self.nodata_check(img, labels)
                            else:
                                skip_chip = self.nodata_check(img)
                            if skip_chip:
                                num_skipped_chips += 1
                                continue

                        img = _finalize(img, self.image_transform, group)
                        # NOTE(review): img is expected to be a floating-point tensor and
                        # labels a torch Long (int64) tensor. An earlier note said
                        # 'torch "Double" (i.e. np.float32)', which is contradictory
                        # (torch Double is float64). Confirm which is intended.
                        if self.use_labels:
                            yield img, _finalize(labels, self.label_transform, group)
                        else:
                            yield img
                except RasterioIOError:
                    # NOTE(caleb): catches occasional errors seen when reading from COGs
                    print("WARNING: Reading %s failed, skipping..." % (img_fn))
            finally:
                img_fp.close()
                if label_fp is not None:
                    label_fp.close()

            if num_skipped_chips > 0 and self.verbose:
                print("We skipped %d chips on %s" % (num_skipped_chips, img_fn))