people-and-planet-ai/weather-forecasting/visualize.py (100 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to visualize data.
Color names from https://chir.ag/projects/name-that-color
"""
from __future__ import annotations
import numpy as np
import plotly.graph_objects as graph_objects
from plotly.subplots import make_subplots
def render_rgb_images(
values: np.ndarray, min: float = 0.0, max: float = 1.0
) -> np.ndarray:
"""Renders a numeric NumPy array with shape (width, height, rgb) as an image.
Args:
values: A float array with shape (width, height, rgb).
min: Minimum value in the values.
max: Maximum value in the values.
Returns: An uint8 array with shape (width, height, rgb).
"""
scaled_values = (values - min) / (max - min)
rgb_values = scaled_values.clip(0, 1) * 255
return rgb_values.astype(np.uint8)
def render_palette(
values: np.ndarray, palette: list[str], min: float = 0.0, max: float = 1.0
) -> np.ndarray:
"""Renders a NumPy array with shape (width, height, 1) as an image with a palette.
Args:
values: An uint8 array with shape (width, height, 1).
palette: List of hex encoded colors.
Returns: An uint8 array with shape (width, height, rgb) with colors from the palette.
"""
# Create a color map from a hex color palette.
xs = np.linspace(0, len(palette), 256)
indices = np.arange(len(palette))
red = np.interp(xs, indices, [int(c[0:2], 16) for c in palette])
green = np.interp(xs, indices, [int(c[2:4], 16) for c in palette])
blue = np.interp(xs, indices, [int(c[4:6], 16) for c in palette])
color_map = np.array([red, green, blue]).astype(np.uint8).transpose()
scaled_values = (values - min) / (max - min)
color_indices = (scaled_values.clip(0, 1) * 255).astype(np.uint8)
return np.take(color_map, color_indices, axis=0)
def render_goes16(patch: np.ndarray) -> np.ndarray:
red = patch[:, :, 1] # CMI_C02
green = patch[:, :, 2] # CMI_C03
blue = patch[:, :, 0] # CMI_C01
rgb_patch = np.stack([red, green, blue], axis=-1)
return render_rgb_images(rgb_patch, max=3000)
def render_gpm(patch: np.ndarray) -> np.ndarray:
palette = [
"000096", # Navy blue
"0064ff", # Blue ribbon blue
"00b4ff", # Dodger blue
"33db80", # Shamrock green
"9beb4a", # Conifer green
"ffeb00", # Turbo yellow
"ffb300", # Selective yellow
"ff6400", # Blaze orange
"eb1e00", # Scarlet red
"af0000", # Bright red
]
return render_palette(patch[:, :, 0], palette, max=20)
def render_elevation(patch: np.ndarray) -> np.ndarray:
palette = [
"000000", # Black
"478fcd", # Shakespeare blue
"86c58e", # De York green
"afc35e", # Celery green
"8f7131", # Pesto brown
"b78d4f", # Muddy waters brown
"e2b8a6", # Rose fog pink
"ffffff", # White
]
return render_palette(patch[:, :, 0], palette, max=3000)
def show_inputs(patch: np.ndarray) -> None:
fig = make_subplots(rows=2, cols=4)
fig.add_trace(graph_objects.Image(z=render_gpm(patch[:, :, 0:1])), row=1, col=1)
fig.add_trace(graph_objects.Image(z=render_gpm(patch[:, :, 1:2])), row=1, col=2)
fig.add_trace(graph_objects.Image(z=render_gpm(patch[:, :, 2:3])), row=1, col=3)
fig.add_trace(graph_objects.Image(z=render_goes16(patch[:, :, 3:19])), row=2, col=1)
fig.add_trace(
graph_objects.Image(z=render_goes16(patch[:, :, 19:35])), row=2, col=2
)
fig.add_trace(
graph_objects.Image(z=render_goes16(patch[:, :, 35:51])), row=2, col=3
)
fig.add_trace(
graph_objects.Image(z=render_elevation(patch[:, :, 51:52])), row=1, col=4
)
fig.update_layout(height=500, margin=dict(l=0, r=0, b=0, t=0))
fig.show()
def show_outputs(patch: np.ndarray) -> None:
fig = make_subplots(rows=1, cols=2)
fig.add_trace(graph_objects.Image(z=render_gpm(patch[:, :, 0:1])), row=1, col=1)
fig.add_trace(graph_objects.Image(z=render_gpm(patch[:, :, 1:2])), row=1, col=2)
fig.update_layout(height=300, margin=dict(l=0, r=0, b=0, t=0))
fig.show()
def show_predictions(results: list[tuple]) -> None:
fig = make_subplots(rows=5, cols=len(results), vertical_spacing=0.025)
for i, (inputs, predictions, labels) in enumerate(results, start=1):
fig.add_trace(
graph_objects.Image(z=render_goes16(inputs[:, :, 35:51])), row=1, col=i
)
fig.add_trace(
graph_objects.Image(z=render_gpm(inputs[:, :, 2:3])), row=2, col=i
)
fig.add_trace(
graph_objects.Image(z=render_elevation(inputs[:, :, 51:52])), row=3, col=i
)
fig.add_trace(
graph_objects.Image(z=render_gpm(predictions[:, :, 0:1])), row=4, col=i
)
fig.add_trace(
graph_objects.Image(z=render_gpm(labels[:, :, 0:1])), row=5, col=i
)
fig.update_layout(
height=5 * int(1000 / len(results)),
margin=dict(l=0, r=0, b=0, t=0),
)
fig.show()