people-and-planet-ai/weather-forecasting/serving/weather-model/weather/trainer.py (124 lines of code) (raw):

# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Trains a model to predict precipitation.""" from __future__ import annotations from glob import glob import os from datasets.arrow_dataset import Dataset from datasets.dataset_dict import DatasetDict import numpy as np from transformers import Trainer, TrainingArguments from weather.model import WeatherModel # Default values. EPOCHS = 100 BATCH_SIZE = 512 TRAIN_TEST_RATIO = 0.9 # Constants. NUM_DATASET_READ_PROC = 16 # number of processes to read data files in parallel NUM_DATASET_PROC = os.cpu_count() or 8 # number of processes for CPU transformations def read_dataset(data_path: str, train_test_ratio: float) -> DatasetDict: """Reads data files into a Dataset with train/test splits.""" def read_data_file(item: dict[str, str]) -> dict[str, np.ndarray]: with open(item["filename"], "rb") as f: npz = np.load(f) return {"inputs": npz["inputs"], "labels": npz["labels"]} def flatten(batch: dict) -> dict: return {key: np.concatenate(values) for key, values in batch.items()} files = glob(os.path.join(data_path, "*.npz")) dataset = ( Dataset.from_dict({"filename": files}) .map( read_data_file, num_proc=NUM_DATASET_READ_PROC, remove_columns=["filename"], ) .map(flatten, batched=True, num_proc=NUM_DATASET_PROC) ) return dataset.train_test_split(train_size=train_test_ratio, shuffle=True) def augmented(dataset: Dataset) -> Dataset: """Augments dataset by rotating and flipping the examples.""" def augment(values: list) -> np.ndarray: transformed = [ np.rot90(values, 0, (1, 2)), np.rot90(values, 1, (1, 2)), np.rot90(values, 2, (1, 2)), np.rot90(values, 3, (1, 2)), np.flip(np.rot90(values, 0, (1, 2)), axis=1), np.flip(np.rot90(values, 1, (1, 2)), axis=1), np.flip(np.rot90(values, 2, (1, 2)), axis=1), np.flip(np.rot90(values, 3, (1, 2)), axis=1), ] return np.concatenate(transformed) return dataset.map( lambda batch: {key: augment(values) for key, values in batch.items()}, batched=True, num_proc=NUM_DATASET_PROC, ) def run( data_path: str, model_path: str, epochs: int = EPOCHS, batch_size: int = BATCH_SIZE, train_test_ratio: float = TRAIN_TEST_RATIO, from_checkpoint: bool = False, ) -> None: """Trains a new WeatherModel. Args: data_path: Directory path to read data files from. model_path Directory path to write the trained model to. epochs: Number of times to go through the training dataset. batch_size: Number of training examples to learn from at once. train_test_ratio: Ratio of examples to use for training and for testing. from_checkpoint: Whether or not to resume from latest checkpoint. """ print(f"data_path: {data_path}") print(f"model_path: {model_path}") print(f"epochs: {epochs}") print(f"batch_size: {batch_size}") print(f"train_test_ratio: {train_test_ratio}") print("-" * 40) dataset = read_dataset(data_path, train_test_ratio) print(dataset) model = WeatherModel.create(dataset["train"]["inputs"]) print(model.config) print(model) training_args = TrainingArguments( output_dir=os.path.join(model_path, "checkpoints"), per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, num_train_epochs=epochs, logging_strategy="epoch", evaluation_strategy="epoch", ) trainer = Trainer( model, training_args, train_dataset=augmented(dataset["train"]), eval_dataset=dataset["test"], ) trainer.train(resume_from_checkpoint=from_checkpoint) trainer.save_model(model_path) def main() -> None: import argparse parser = argparse.ArgumentParser() parser.add_argument( "--data-path", required=True, help="Directory path to read data files from.", ) parser.add_argument( "--model-path", required=True, help="Directory path to write the trained model to.", ) parser.add_argument( "--epochs", type=int, default=EPOCHS, help="Number of times to go through the training dataset.", ) parser.add_argument( "--batch-size", type=int, default=BATCH_SIZE, help="Number of training examples to learn from at once.", ) parser.add_argument( "--train-test-ratio", type=float, default=TRAIN_TEST_RATIO, help="Ratio of examples to use for training and for testing.", ) parser.add_argument( "--from-checkpoint", action="store_true", help="Whether or not to resume from latest checkpoint.", ) args = parser.parse_args() run(**vars(args)) if __name__ == "__main__": main()