people-and-planet-ai/timeseries-classification/trainer.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import reduce
import logging
import os
from typing import TypeVar

import tensorflow as tf
from tensorflow import keras

# Generic type variable, used to annotate values that can be converted to
# tensors as well as `self` in the nested Keras layer classes below.
a = TypeVar("a")

INPUTS_SPEC = {
    "distance_from_port": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
    "speed": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
    "course": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
    "lat": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
    "lon": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
}

OUTPUTS_SPEC = {
    "is_fishing": tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
}

PADDING = 24


def validated(
    tensor_dict: dict[str, tf.Tensor],
    spec_dict: dict[str, tf.TypeSpec],
) -> dict[str, tf.Tensor]:
    # Check that every field in the spec is present with a compatible dtype
    # and shape, and return the tensor dictionary unchanged.
    for field, spec in spec_dict.items():
        if field not in tensor_dict:
            raise KeyError(
                f"missing field '{field}', got={tensor_dict.keys()}, expected={spec_dict.keys()}"
            )
        if not spec.dtype.is_compatible_with(tensor_dict[field].dtype):
            raise TypeError(
                f"incompatible type in '{field}', got={tensor_dict[field].dtype}, expected={spec.dtype}"
            )
        if not spec.shape.is_compatible_with(tensor_dict[field].shape):
            raise ValueError(
                f"incompatible shape in '{field}', got={tensor_dict[field].shape}, expected={spec.shape}"
            )
    return tensor_dict


def serialize(value_dict: dict[str, a]) -> bytes:
    # Convert a dictionary of input and output values into a serialized
    # tf.train.Example, storing each tensor as a serialized bytes feature.
    spec_dict = {**INPUTS_SPEC, **OUTPUTS_SPEC}
    tensor_dict = {
        field: tf.convert_to_tensor(value, spec_dict[field].dtype)
        for field, value in value_dict.items()
    }
    validated_tensor_dict = validated(tensor_dict, spec_dict)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                field: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[tf.io.serialize_tensor(value).numpy()]
                    )
                )
                for field, value in validated_tensor_dict.items()
            }
        )
    )
    return example.SerializeToString()


def deserialize(
    serialized_example: bytes,
) -> tuple[dict[str, tf.Tensor], dict[str, tf.Tensor]]:
    # Parse a serialized tf.train.Example back into an (inputs, outputs)
    # pair of tensor dictionaries.
    features = {
        field: tf.io.FixedLenFeature(shape=(), dtype=tf.string)
        for field in [*INPUTS_SPEC.keys(), *OUTPUTS_SPEC.keys()]
    }
    example = tf.io.parse_example(serialized_example, features)

    def parse_tensor(bytes_value: bytes, spec: tf.TypeSpec) -> tf.Tensor:
        tensor = tf.io.parse_tensor(bytes_value, spec.dtype)
        tensor.set_shape(spec.shape)
        return tensor

    def parse_features(spec_dict: dict[str, tf.TypeSpec]) -> dict[str, tf.Tensor]:
        tensor_dict = {
            field: parse_tensor(bytes_value, spec_dict[field])
            for field, bytes_value in example.items()
            if field in spec_dict
        }
        return validated(tensor_dict, spec_dict)

    return parse_features(INPUTS_SPEC), parse_features(OUTPUTS_SPEC)


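# Illustrative round trip through the helpers above (not part of the training
# pipeline). The field values and the 32 time steps are made-up placeholders,
# shown only to clarify the expected shapes: each feature is a sequence of
# single-value time steps, compatible with a (None, 1) TensorSpec.
#
#   example_bytes = serialize(
#       {
#           "distance_from_port": [[10.0]] * 32,
#           "speed": [[1.5]] * 32,
#           "course": [[0.7]] * 32,
#           "lat": [[0.1]] * 32,
#           "lon": [[0.2]] * 32,
#           "is_fishing": [[1.0]] * 32,
#       }
#   )
#   inputs, outputs = deserialize(example_bytes)

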
def create_dataset(data_dir: str, batch_size: int) -> tf.data.Dataset:
    file_names = tf.io.gfile.glob(f"{data_dir}/*")
    return (
        tf.data.TFRecordDataset(file_names, compression_type="GZIP")
        .map(deserialize, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(batch_size * 128)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


def create_model(train_dataset: tf.data.Dataset) -> keras.Model:
    input_layers = {
        name: keras.layers.Input(shape=spec.shape, dtype=spec.dtype, name=name)
        for name, spec in INPUTS_SPEC.items()
    }

    def normalize(name: str) -> keras.layers.Layer:
        layer = keras.layers.Normalization(name=f"{name}_normalized")
        layer.adapt(train_dataset.map(lambda inputs, outputs: inputs[name]))
        return layer(input_layers[name])

    def direction(course_name: str) -> keras.layers.Layer:
        class Direction(keras.layers.Layer):
            def call(self: a, course: tf.Tensor) -> tf.Tensor:
                x = tf.cos(course)
                y = tf.sin(course)
                return tf.concat([x, y], axis=-1)

        input_layer = input_layers[course_name]
        return Direction(name=f"{course_name}_direction")(input_layer)

    def geo_point(lat_name: str, lon_name: str) -> keras.layers.Layer:
        # We transform each (lat, lon) pair into a 3D point in the unit sphere.
        # https://en.wikipedia.org/wiki/Spherical_coordinate_system#Cartesian_coordinates
        class GeoPoint(keras.layers.Layer):
            def call(self: a, latlon: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
                lat, lon = latlon
                x = tf.cos(lon) * tf.sin(lat)
                y = tf.sin(lon) * tf.sin(lat)
                z = tf.cos(lat)
                return tf.concat([x, y, z], axis=-1)

        lat_lon_input_layers = (input_layers[lat_name], input_layers[lon_name])
        return GeoPoint(name=f"{lat_name}_{lon_name}")(lat_lon_input_layers)

    def sequential_layers(
        first_layer: keras.layers.Layer, *layers: keras.layers.Layer
    ) -> keras.layers.Layer:
        return reduce(lambda layer, result: result(layer), layers, first_layer)

    preprocessed_inputs = [
        normalize("distance_from_port"),
        normalize("speed"),
        direction("course"),
        geo_point("lat", "lon"),
    ]
    output_layers = {
        "is_fishing": sequential_layers(
            keras.layers.concatenate(preprocessed_inputs, name="deep_layers"),
            keras.layers.Conv1D(
                filters=32,
                kernel_size=PADDING + 1,
                data_format="channels_last",
                activation="relu",
            ),
            keras.layers.Dense(16, activation="relu"),
            keras.layers.Dense(1, activation="sigmoid", name="is_fishing"),
        )
    }
    return keras.Model(input_layers, output_layers)


def run(
    train_data_dir: str,
    eval_data_dir: str,
    train_epochs: int,
    batch_size: int,
    model_dir: str,
    checkpoint_dir: str,
    tensorboard_dir: str,
) -> None:
    # For this sample we are using a mirrored distribution strategy,
    # which consists of a single machine with multiple GPUs.
    # https://blog.tensorflow.org/2020/12/getting-started-with-distributed-tensorflow-on-gcp.html
    distributed_strategy = tf.distribute.MirroredStrategy()
    # distributed_strategy = tf.distribute.get_strategy()

    # Create the training and evaluation datasets from the TFRecord files.
    logging.info("Creating datasets")
    train_batch_size = batch_size * distributed_strategy.num_replicas_in_sync
    train_dataset = create_dataset(train_data_dir, train_batch_size)
    eval_dataset = create_dataset(eval_data_dir, batch_size)

    # Create and compile the model inside the distribution strategy scope.
    with distributed_strategy.scope():
        logging.info("Creating the model")
        model = create_model(train_dataset)

        logging.info("Compiling the model")
        model.compile(
            optimizer="adam",  # https://keras.io/api/optimizers
            loss={"is_fishing": "binary_crossentropy"},  # https://keras.io/api/losses
            metrics={"is_fishing": ["accuracy"]},  # https://keras.io/api/metrics
        )

    # Train the model.
    logging.info("Training the model")
    model.fit(
        train_dataset,
        epochs=train_epochs,
        validation_data=eval_dataset,
        callbacks=[
            keras.callbacks.TensorBoard(tensorboard_dir, update_freq="batch"),
            keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_dir + "/{epoch}",
                save_best_only=True,  # Only save a model if `val_loss` has improved.
                monitor="val_loss",
                verbose=1,
            ),
        ],
    )

    # Save the trained model.
    logging.info(f"Saving the model: {model_dir}")
    model.save(model_dir)


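# Illustrative only: `run` can also be called directly for a quick local
# experiment; the directories and hyperparameter values below are made-up
# placeholders, not values prescribed by this sample.
#
#   run(
#       train_data_dir="datasets/training",
#       eval_data_dir="datasets/validation",
#       train_epochs=2,
#       batch_size=16,
#       model_dir="model",
#       checkpoint_dir="checkpoints",
#       tensorboard_dir="tensorboard",
#   )

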
logging.info(f"Saving the model: {model_dir}") model.save(model_dir) if __name__ == "__main__": import argparse # TODO: Have either: hardcoded default values if possible, or have everything required parser = argparse.ArgumentParser() parser.add_argument( "--train-data-dir", required=True, help="Cloud Storage directory containing training TFRecord files.", ) parser.add_argument( "--eval-data-dir", required=True, help="Cloud Storage directory containing evaluation TFRecord files.", ) parser.add_argument( "--train-epochs", type=int, required=True, help="Number of times to go through the training dataset.", ) parser.add_argument( "--batch-size", type=int, required=True, help="Batch size for the training and evaluation datasets.", ) parser.add_argument( "--model-dir", default=os.environ.get("AIP_MODEL_DIR", "model"), help="Directory to save the trained model.", ) parser.add_argument( "--checkpoint-dir", default=os.environ.get("AIP_CHECKPOINT_DIR", "checkpoints"), help="Directory to save model checkpoints during training.", ) parser.add_argument( "--tensorboard-dir", default=os.environ.get("AIP_TENSORBOARD_LOG_DIR", "tensorboard"), help="Directory to save TensorBoard logs.", ) args = parser.parse_args() logging.getLogger().setLevel(logging.INFO) run( train_data_dir=args.train_data_dir, eval_data_dir=args.eval_data_dir, train_epochs=args.train_epochs, batch_size=args.batch_size, model_dir=args.model_dir, checkpoint_dir=args.checkpoint_dir, tensorboard_dir=args.tensorboard_dir, )