examples/mixed_hmm/seal_data.py (43 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import os
from urllib.request import urlopen
import pandas as pd
import torch
MISSING = 1e-6
def download_seal_data(filename):
"""download the preprocessed seal data and save it to filename"""
url = "https://d2hg8soec8ck9v.cloudfront.net/datasets/prep_seal_data.csv"
with open(filename, "wb") as f:
f.write(urlopen(url).read())
def prepare_seal(filename, random_effects):
if not os.path.exists(filename):
download_seal_data(filename)
seal_df = pd.read_csv(filename)
obs_keys = ["step", "angle", "omega"]
# data format for z1, z2:
# single tensor with shape (individual, group, time, coords)
observations = torch.zeros((20, 2, 1800, len(obs_keys))).fill_(float("-inf"))
for g, (group, group_df) in enumerate(seal_df.groupby("sex")):
for i, (ind, ind_df) in enumerate(group_df.groupby("ID")):
for o, obs_key in enumerate(obs_keys):
observations[i, g, 0:len(ind_df), o] = torch.tensor(ind_df[obs_key].values)
observations[torch.isnan(observations)] = float("-inf")
# make masks
# mask_i should mask out individuals, it applies at all timesteps
mask_i = (observations > float("-inf")).any(dim=-1).any(dim=-1) # time nonempty
# mask_t handles padding for time series of different length
mask_t = (observations > float("-inf")).all(dim=-1) # include non-inf
# temporary hack to avoid zero-inflation issues
# observations[observations == 0.] = MISSING
observations[(observations == 0.) | (observations == float("-inf"))] = MISSING
assert not torch.isnan(observations).any()
# observations = observations[..., 5:11, :] # truncate for testing
config = {
"MISSING": MISSING,
"sizes": {
"state": 3,
"random": 4,
"group": observations.shape[1],
"individual": observations.shape[0],
"timesteps": observations.shape[2],
},
"group": {"random": random_effects["group"], "fixed": None},
"individual": {"random": random_effects["individual"], "fixed": None, "mask": mask_i},
"timestep": {"random": None, "fixed": None, "mask": mask_t},
"observations": {
"step": observations[..., 0],
"angle": observations[..., 1],
"omega": observations[..., 2],
},
}
return config