# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trains a model to predict precipitation."""
from __future__ import annotations
from glob import glob
import os
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict
import numpy as np
from transformers import Trainer, TrainingArguments
from weather.model import WeatherModel
# Default values.
EPOCHS = 100
BATCH_SIZE = 512
TRAIN_TEST_RATIO = 0.9
# Constants.
NUM_DATASET_READ_PROC = 16 # number of processes to read data files in parallel
NUM_DATASET_PROC = os.cpu_count() or 8 # number of processes for CPU transformations
def read_dataset(data_path: str, train_test_ratio: float) -> DatasetDict:
"""Reads data files into a Dataset with train/test splits."""
def read_data_file(item: dict[str, str]) -> dict[str, np.ndarray]:
with open(item["filename"], "rb") as f:
npz = np.load(f)
return {"inputs": npz["inputs"], "labels": npz["labels"]}
def flatten(batch: dict) -> dict:
return {key: np.concatenate(values) for key, values in batch.items()}
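    # `read_data_file` yields one row per file, where each column holds a stacked
    # array of examples; mapping `flatten` with batched=True concatenates those
    # stacks so that every dataset row becomes a single (inputs, labels) example.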
files = glob(os.path.join(data_path, "*.npz"))
dataset = (
Dataset.from_dict({"filename": files})
.map(
read_data_file,
num_proc=NUM_DATASET_READ_PROC,
remove_columns=["filename"],
)
.map(flatten, batched=True, num_proc=NUM_DATASET_PROC)
)
return dataset.train_test_split(train_size=train_test_ratio, shuffle=True)
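
# A minimal usage sketch (the path is hypothetical). The returned DatasetDict has
# "train" and "test" splits, each with "inputs" and "labels" columns:
#
#   dataset = read_dataset("data/", train_test_ratio=0.9)
#   train_inputs = dataset["train"]["inputs"]
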
def augmented(dataset: Dataset) -> Dataset:
"""Augments dataset by rotating and flipping the examples."""
def augment(values: list) -> np.ndarray:
transformed = [
np.rot90(values, 0, (1, 2)),
np.rot90(values, 1, (1, 2)),
np.rot90(values, 2, (1, 2)),
np.rot90(values, 3, (1, 2)),
np.flip(np.rot90(values, 0, (1, 2)), axis=1),
np.flip(np.rot90(values, 1, (1, 2)), axis=1),
np.flip(np.rot90(values, 2, (1, 2)), axis=1),
np.flip(np.rot90(values, 3, (1, 2)), axis=1),
]
return np.concatenate(transformed)
return dataset.map(
lambda batch: {key: augment(values) for key, values in batch.items()},
batched=True,
num_proc=NUM_DATASET_PROC,
)
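
# The eight transforms in `augment` are the symmetries of a square (four 90-degree
# rotations, each with and without a flip), so the augmented dataset is roughly
# eight times the size of the original. The rotations act on axes (1, 2), which are
# presumably the spatial height/width dimensions of each batched array.
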
def run(
data_path: str,
model_path: str,
epochs: int = EPOCHS,
batch_size: int = BATCH_SIZE,
train_test_ratio: float = TRAIN_TEST_RATIO,
from_checkpoint: bool = False,
) -> None:
"""Trains a new WeatherModel.
Args:
data_path: Directory path to read data files from.
        model_path: Directory path to write the trained model to.
epochs: Number of times to go through the training dataset.
batch_size: Number of training examples to learn from at once.
train_test_ratio: Ratio of examples to use for training and for testing.
        from_checkpoint: Whether or not to resume from the latest checkpoint.
"""
print(f"data_path: {data_path}")
print(f"model_path: {model_path}")
print(f"epochs: {epochs}")
print(f"batch_size: {batch_size}")
print(f"train_test_ratio: {train_test_ratio}")
print("-" * 40)
dataset = read_dataset(data_path, train_test_ratio)
print(dataset)
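    # Build a fresh model from the training inputs; presumably WeatherModel.create
    # uses them to infer the input shape and/or normalization statistics
    # (see weather/model.py).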
model = WeatherModel.create(dataset["train"]["inputs"])
print(model.config)
print(model)
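    # Trainer configuration: checkpoints are written under <model_path>/checkpoints,
    # and logging and evaluation both run once per epoch. Note: newer transformers
    # releases may expect `eval_strategy` instead of `evaluation_strategy`.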
training_args = TrainingArguments(
output_dir=os.path.join(model_path, "checkpoints"),
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=epochs,
logging_strategy="epoch",
evaluation_strategy="epoch",
)
trainer = Trainer(
model,
training_args,
train_dataset=augmented(dataset["train"]),
eval_dataset=dataset["test"],
)
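    # resume_from_checkpoint=True tells the Trainer to pick up from the latest
    # checkpoint in the output directory; False starts training from scratch.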
trainer.train(resume_from_checkpoint=from_checkpoint)
trainer.save_model(model_path)
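
# A minimal programmatic usage sketch (paths and values are hypothetical):
#
#   run(data_path="data/", model_path="model/", epochs=10, batch_size=256)
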
def main() -> None:
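    """Parses command-line flags and runs training.

    Example (paths are hypothetical; the exact entry point depends on how the
    package is invoked, e.g. `python -m weather.trainer`):

        python trainer.py --data-path=data/ --model-path=model/
    """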
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-path",
required=True,
help="Directory path to read data files from.",
)
parser.add_argument(
"--model-path",
required=True,
help="Directory path to write the trained model to.",
)
parser.add_argument(
"--epochs",
type=int,
default=EPOCHS,
help="Number of times to go through the training dataset.",
)
parser.add_argument(
"--batch-size",
type=int,
default=BATCH_SIZE,
help="Number of training examples to learn from at once.",
)
parser.add_argument(
"--train-test-ratio",
type=float,
default=TRAIN_TEST_RATIO,
help="Ratio of examples to use for training and for testing.",
)
parser.add_argument(
"--from-checkpoint",
action="store_true",
help="Whether or not to resume from latest checkpoint.",
)
args = parser.parse_args()
run(**vars(args))
if __name__ == "__main__":
main()