# examples/air/modules.py
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from torch.nn.functional import softplus
# Takes pixel intensities of the attention window to parameters (mean,
# standard deviation) of the distribution over the latent code,
# z_what.
class Encoder(nn.Module):
    def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
        super().__init__()
        self.z_size = z_size
        # The final layer emits both halves of the (loc, scale) pair.
        self.mlp = MLP(x_size, h_sizes + [2 * z_size], non_linear_layer)

    def forward(self, x):
        """Map flattened window pixels to (loc, scale) of q(z_what)."""
        out = self.mlp(x)
        loc = out[:, :self.z_size]
        raw_scale = out[:, self.z_size:]
        # softplus keeps the scale strictly positive.
        return loc, softplus(raw_scale)
# Takes a latent code, z_what, to pixel intensities.
class Decoder(nn.Module):
    def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_linear_layer):
        super().__init__()
        self.bias = bias
        self.use_sigmoid = use_sigmoid
        self.mlp = MLP(z_size, h_sizes + [x_size], non_linear_layer)

    def forward(self, z):
        """Decode a z_what latent into pixel intensities."""
        out = self.mlp(z)
        # Optional constant offset applied before the squashing non-linearity.
        if self.bias is not None:
            out = out + self.bias
        if self.use_sigmoid:
            return torch.sigmoid(out)
        return out
# A general purpose module to construct networks that look like:
# [Linear (256 -> 1)]
# [Linear (256 -> 256), ReLU (), Linear (256 -> 1)]
# [Linear (256 -> 256), ReLU (), Linear (256 -> 1), ReLU ()]
# etc.
class MLP(nn.Module):
    def __init__(self, in_size, out_sizes, non_linear_layer, output_non_linearity=False):
        super().__init__()
        assert len(out_sizes) >= 1
        dims = [in_size] + list(out_sizes)
        layers = []
        for j in range(len(dims) - 1):
            layers.append(nn.Linear(dims[j], dims[j + 1]))
            # A fresh non-linearity follows every Linear except, unless
            # requested via output_non_linearity, the last one.
            if j < len(dims) - 2 or output_non_linearity:
                layers.append(non_linear_layer())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)
# Takes the guide RNN hidden state to parameters of the guide
# distributions over z_where and z_pres.
class Predict(nn.Module):
    def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer):
        super().__init__()
        self.z_pres_size = z_pres_size
        self.z_where_size = z_where_size
        # Output packs [z_pres logits | z_where loc | z_where raw scale].
        self.mlp = MLP(input_size, h_sizes + [z_pres_size + 2 * z_where_size], non_linear_layer)

    def forward(self, h):
        """Split the MLP output into (z_pres_p, z_where_loc, z_where_scale)."""
        out = self.mlp(h)
        cut1 = self.z_pres_size
        cut2 = cut1 + self.z_where_size
        z_pres_p = torch.sigmoid(out[:, :cut1])
        z_where_loc = out[:, cut1:cut2]
        # softplus keeps the scale strictly positive.
        z_where_scale = softplus(out[:, cut2:])
        return z_pres_p, z_where_loc, z_where_scale
class Identity(nn.Module):
    """No-op module: forward returns its input unchanged."""

    def forward(self, x):
        return x