def create_model()

in people-and-planet-ai/timeseries-classification/trainer.py [0:0]


def create_model(train_dataset: tf.data.Dataset) -> keras.Model:
    """Build the (uncompiled) fishing-classification Keras model.

    The model has one `keras.Input` per entry in `INPUTS_SPEC` and a single
    output named "is_fishing". Inputs are preprocessed as follows:

    - "distance_from_port" and "speed" go through `Normalization` layers
      whose statistics are adapted on `train_dataset`.
    - "course" is encoded as its (cos, sin) pair.
    - "lat" and "lon" are encoded as a 3D point on the unit sphere.

    The preprocessed features are concatenated along the last axis and fed
    through a `Conv1D` (32 filters, kernel size `PADDING + 1`, ReLU), a
    16-unit ReLU `Dense` layer and a 1-unit sigmoid `Dense` layer.

    Args:
        train_dataset: Dataset of `(inputs, outputs)` pairs where `inputs`
            is indexable by feature name. It is only used to adapt the
            normalization layers, and is iterated once per normalized
            feature.

    Returns:
        A functional `keras.Model` mapping the dict of named inputs to
        `{"is_fishing": <sigmoid output>}`.
    """
    input_layers = {
        name: keras.layers.Input(shape=spec.shape, dtype=spec.dtype, name=name)
        for name, spec in INPUTS_SPEC.items()
    }

    def normalize(name: str) -> tf.Tensor:
        """Normalize input `name` with statistics learned from `train_dataset`."""
        layer = keras.layers.Normalization(name=f"{name}_normalized")
        # adapt() makes a full pass over the (mapped) dataset.
        layer.adapt(train_dataset.map(lambda inputs, outputs: inputs[name]))
        return layer(input_layers[name])

    def direction(course_name: str) -> tf.Tensor:
        """Encode an angle input as (cos, sin), removing the wrap-around jump."""

        class Direction(keras.layers.Layer):
            def call(self, course: tf.Tensor) -> tf.Tensor:
                # NOTE(review): tf.cos/tf.sin take radians; confirm the course
                # feature is not stored in degrees.
                x = tf.cos(course)
                y = tf.sin(course)
                return tf.concat([x, y], axis=-1)

        input_layer = input_layers[course_name]
        return Direction(name=f"{course_name}_direction")(input_layer)

    def geo_point(lat_name: str, lon_name: str) -> tf.Tensor:
        """Map a (latitude, longitude) pair to a 3D point on the unit sphere.

        Latitude is measured from the equator, so the polar angle of the
        spherical coordinate system is theta = pi/2 - lat, which gives
        sin(theta) = cos(lat) and cos(theta) = sin(lat):
          https://en.wikipedia.org/wiki/Spherical_coordinate_system#Cartesian_coordinates
        Using latitude directly as theta would map (lat, lon) and
        (-lat, lon + pi) to the same point, mixing up the hemispheres.
        """

        class GeoPoint(keras.layers.Layer):
            def call(self, latlon: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
                lat, lon = latlon
                # NOTE(review): assumes lat/lon are in radians -- confirm; if
                # they are stored in degrees they must be converted first.
                cos_lat = tf.cos(lat)
                x = cos_lat * tf.cos(lon)
                y = cos_lat * tf.sin(lon)
                z = tf.sin(lat)
                return tf.concat([x, y, z], axis=-1)

        lat_lon_input_layers = (input_layers[lat_name], input_layers[lon_name])
        return GeoPoint(name=f"{lat_name}_{lon_name}")(lat_lon_input_layers)

    def sequential_layers(
        inputs: tf.Tensor, *layers: keras.layers.Layer
    ) -> tf.Tensor:
        """Apply `layers` to `inputs` in order, like a small `keras.Sequential`."""
        return reduce(lambda x, layer: layer(x), layers, inputs)

    preprocessed_inputs = [
        normalize("distance_from_port"),
        normalize("speed"),
        direction("course"),
        geo_point("lat", "lon"),
    ]

    output_layers = {
        "is_fishing": sequential_layers(
            keras.layers.concatenate(preprocessed_inputs, name="deep_layers"),
            # With Keras' default "valid" padding each output step sees
            # PADDING + 1 consecutive input steps, so the time axis shrinks by
            # PADDING (presumably the inputs are padded by that amount).
            keras.layers.Conv1D(
                filters=32,
                kernel_size=PADDING + 1,
                data_format="channels_last",
                activation="relu",
            ),
            keras.layers.Dense(16, activation="relu"),
            keras.layers.Dense(1, activation="sigmoid", name="is_fishing"),
        )
    }
    return keras.Model(input_layers, output_layers)