people-and-planet-ai/weather-forecasting/create_dataset.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Creates a dataset to train a machine learning model."""

from __future__ import annotations

from collections.abc import Iterator
from datetime import datetime, timedelta
import logging
import random
import uuid

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import PipelineOptions
import ee
import numpy as np
import requests

# Default values.
NUM_DATES = 100
MAX_REQUESTS = 20  # default Earth Engine request quota
MIN_BATCH_SIZE = 100

# Constants.
NUM_BINS = 10
MAX_PRECIPITATION = 30  # found empirically
MAX_ELEVATION = 6000  # found empirically
PATCH_SIZE = 5
START_DATE = datetime(2017, 7, 10)
END_DATE = datetime.now() - timedelta(days=30)
POLYGON = [(-140.0, 60.0), (-140.0, -60.0), (-10.0, -60.0), (-10.0, 60.0)]


def sample_points(date: datetime, num_bins: int = NUM_BINS) -> Iterator[tuple]:
    """Selects around the same number of points for every classification.

    Since our labels are continuous numeric values, we convert them into
    integers within a predefined range. Each integer value is treated as a
    different classification.

    From analyzing the precipitation data, most values fall between 0 mm/hr
    and 30 mm/hr of precipitation (rain and snow), but values above 30 mm/hr
    still exist. So we clamp the values to between 0 mm/hr and 30 mm/hr, and
    then bucketize them. We do the same for the elevation, and finally get a
    "unique" bin number by combining the precipitation and elevation bins.

    We do this because most of the precipitation values fall at elevation
    zero, so the data would otherwise be extremely biased.

    Args:
        date: The date of interest.
        num_bins: Number of bins to bucketize values.

    Yields:
        (date, lon_lat) pairs.
    """
    from weather import data

    precipitation_bins = (
        data.get_gpm(date)
        .clamp(0, MAX_PRECIPITATION)
        .divide(MAX_PRECIPITATION)
        .multiply(num_bins - 1)
        .uint8()
    )
    elevation_bins = (
        data.get_elevation()
        .clamp(0, MAX_ELEVATION)
        .divide(MAX_ELEVATION)
        .multiply(num_bins - 1)
        .uint8()
    )
    unique_bins = elevation_bins.multiply(num_bins).add(precipitation_bins)
    points = unique_bins.stratifiedSample(
        numPoints=1,
        region=ee.Geometry.Polygon(POLYGON),
        scale=data.SCALE,
        geometries=True,
    )
    for point in points.toList(points.size()).getInfo():
        yield (date, point["geometry"]["coordinates"])
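

# NOTE: The binning above runs server-side as Earth Engine image operations,
# so the resulting stratum values are never visible locally. The helper below
# is an illustrative sketch only (it is not used by the pipeline): it mirrors
# the same clamp/scale/truncate arithmetic for a single scalar reading.
# For example, 15 mm/hr of precipitation at sea level falls in bin 4, while
# 15 mm/hr at 3000 m of elevation falls in bin 44.
def _example_bin_index(
    precipitation: float, elevation: float, num_bins: int = NUM_BINS
) -> int:
    """Returns the combined stratum index for one (precipitation, elevation) reading."""
    precipitation_bin = int(
        min(max(precipitation, 0), MAX_PRECIPITATION)
        / MAX_PRECIPITATION
        * (num_bins - 1)
    )
    elevation_bin = int(
        min(max(elevation, 0), MAX_ELEVATION) / MAX_ELEVATION * (num_bins - 1)
    )
    # Same combination as `unique_bins` above: one stratum per
    # (elevation_bin, precipitation_bin) pair.
    return elevation_bin * num_bins + precipitation_bin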
""" from weather import data return ( data.get_inputs_patch(date, point, patch_size), data.get_labels_patch(date, point, patch_size), ) def try_get_example(date: datetime, point: tuple) -> Iterator[tuple]: """Wrapper over `get_training_examples` that allows it to simply log errors instead of crashing.""" try: yield get_training_example(date, point) except (requests.exceptions.HTTPError, ee.ee_exception.EEException) as e: logging.error(f"🛑 failed to get example: {date} {point}") logging.exception(e) def write_npz(batch: list[tuple[np.ndarray, np.ndarray]], data_path: str) -> str: """Writes an (inputs, labels) batch into a compressed NumPy file. Args: batch: Batch of (inputs, labels) pairs of NumPy arrays. data_path: Directory path to save files to. Returns: The filename of the data file. """ filename = FileSystems.join(data_path, f"{uuid.uuid4()}.npz") with FileSystems.create(filename) as f: inputs = [x for (x, _) in batch] labels = [y for (_, y) in batch] np.savez_compressed(f, inputs=inputs, labels=labels) logging.info(filename) return filename def run( data_path: str, num_dates: int = NUM_DATES, num_bins: int = NUM_BINS, max_requests: int = MAX_REQUESTS, min_batch_size: int = MIN_BATCH_SIZE, beam_args: list[str] | None = None, ) -> None: """Runs an Apache Beam pipeline to create a dataset. This fetches data from Earth Engine and writes compressed NumPy files. We use `max_requests` to limit the number of concurrent requests to Earth Engine to avoid quota issues. You can request for an increas of quota if you need it. Args: data_path: Directory path to save the data files. num_dates: Number of dates to extract data points from. num_bins: Number of bins to bucketize values. max_requests: Limit the number of concurrent requests to Earth Engine. min_batch_size: Minimum number of examples to write per data file. beam_args: Apache Beam command line arguments to parse as pipeline options. 
""" random_dates = [ START_DATE + (END_DATE - START_DATE) * random.random() for _ in range(num_dates) ] beam_options = PipelineOptions( beam_args, save_main_session=True, direct_num_workers=max(max_requests, MAX_REQUESTS), # direct runner max_num_workers=max_requests, # distributed runners ) with beam.Pipeline(options=beam_options) as pipeline: ( pipeline | "📆 Random dates" >> beam.Create(random_dates) | "📌 Sample points" >> beam.FlatMap(sample_points, num_bins) | "🃏 Reshuffle" >> beam.Reshuffle() | "📑 Get example" >> beam.FlatMapTuple(try_get_example) | "🗂️ Batch examples" >> beam.BatchElements(min_batch_size) | "📝 Write NPZ files" >> beam.Map(write_npz, data_path) ) def main() -> None: import argparse logging.getLogger().setLevel(logging.INFO) parser = argparse.ArgumentParser() parser.add_argument( "--data-path", required=True, help="Directory path to save the data files", ) parser.add_argument( "--num-dates", type=int, default=NUM_DATES, help="Number of dates to extract data points from.", ) parser.add_argument( "--num-bins", type=int, default=NUM_BINS, help="Number of bins to bucketize values.", ) parser.add_argument( "--max-requests", type=int, default=MAX_REQUESTS, help="Limit the number of concurrent requests to Earth Engine.", ) parser.add_argument( "--min-batch-size", type=int, default=MIN_BATCH_SIZE, help="Minimum number of examples to write per data file.", ) args, beam_args = parser.parse_known_args() run( data_path=args.data_path, num_dates=args.num_dates, num_bins=args.num_bins, max_requests=args.max_requests, min_batch_size=args.min_batch_size, beam_args=beam_args, ) if __name__ == "__main__": main()