examples/air/modules.py (57 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
from torch.nn.functional import softplus


class Encoder(nn.Module):
    """Map attention-window pixel intensities to the parameters of z_what.

    Takes pixel intensities of the attention window to the parameters
    (mean, standard deviation) of the distribution over the latent code
    z_what. The MLP emits ``2 * z_size`` features per row. The first
    ``z_size`` features are the mean. The remaining ``z_size`` features,
    passed through softplus so they are positive, are the standard deviation.

    Args:
        x_size: Number of input features (attention-window pixels).
        h_sizes: Hidden-layer widths; any sequence of ints (list or tuple).
        z_size: Dimensionality of z_what.
        non_linear_layer: Zero-argument callable returning an activation
            module (e.g. ``nn.ReLU``), inserted between linear layers.
    """

    def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
        super().__init__()
        self.z_size = z_size
        output_size = 2 * z_size
        # list(...) so tuples (or other sequences) of hidden sizes also work.
        self.mlp = MLP(x_size, list(h_sizes) + [output_size], non_linear_layer)

    def forward(self, x):
        """Return ``(loc, scale)`` for z_what.

        ``x`` is presumably a ``(batch, x_size)`` tensor; the MLP output is
        sliced along dim 1, so both results have ``z_size`` columns.
        """
        a = self.mlp(x)
        loc = a[:, 0:self.z_size]
        scale = softplus(a[:, self.z_size:])
        return loc, scale


class Decoder(nn.Module):
    """Map a latent code z_what to pixel intensities.

    Args:
        x_size: Number of output features (pixel intensities).
        h_sizes: Hidden-layer widths; any sequence of ints (list or tuple).
        z_size: Dimensionality of the input latent code.
        bias: Optional value added to the MLP output before the optional
            sigmoid; ``None`` disables it.
        use_sigmoid: If true, squash the output through ``torch.sigmoid``.
        non_linear_layer: Zero-argument callable returning an activation
            module, inserted between linear layers.
    """

    def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_linear_layer):
        super().__init__()
        # NOTE(review): a plain tensor assigned here is not registered as a
        # parameter/buffer, so Module.to(device) will not move it; only an
        # nn.Parameter is tracked. Confirm callers pass a device-matched bias.
        self.bias = bias
        self.use_sigmoid = use_sigmoid
        self.mlp = MLP(z_size, list(h_sizes) + [x_size], non_linear_layer)

    def forward(self, z):
        """Decode ``z`` to (optionally sigmoid-squashed) pixel intensities."""
        a = self.mlp(z)
        if self.bias is not None:
            a = a + self.bias
        return torch.sigmoid(a) if self.use_sigmoid else a


class MLP(nn.Module):
    """General-purpose multilayer perceptron.

    Builds networks that look like::

        [Linear (256 -> 1)]
        [Linear (256 -> 256), ReLU (), Linear (256 -> 1)]
        [Linear (256 -> 256), ReLU (), Linear (256 -> 1), ReLU ()]

    Args:
        in_size: Number of input features.
        out_sizes: Output width of each linear layer, in order; any non-empty
            sequence of ints. The last entry is the network's output width.
        non_linear_layer: Zero-argument callable returning an activation
            module, inserted after every linear layer except the last.
        output_non_linearity: If true, also append an activation after the
            final linear layer.

    Raises:
        ValueError: If ``out_sizes`` is empty.
    """

    def __init__(self, in_size, out_sizes, non_linear_layer, output_non_linearity=False):
        super().__init__()
        out_sizes = list(out_sizes)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not out_sizes:
            raise ValueError("MLP requires at least one output size")
        in_sizes = [in_size] + out_sizes[:-1]
        layers = []
        # Hidden layers: Linear followed by an activation. The order matches
        # the original construction so Sequential indices (and therefore
        # state_dict keys) are unchanged.
        for i, o in zip(in_sizes[:-1], out_sizes[:-1]):
            layers.append(nn.Linear(i, o))
            layers.append(non_linear_layer())
        layers.append(nn.Linear(in_sizes[-1], out_sizes[-1]))
        if output_non_linearity:
            layers.append(non_linear_layer())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.seq(x)


# Takes the guide RNN hidden state to parameters of the guide
# distributions over z_where and z_pres.
class Predict(nn.Module):
    """Map the guide RNN hidden state to guide parameters for z_pres and z_where.

    The MLP emits ``z_pres_size + 2 * z_where_size`` features per row, which
    are split along dim 1 into three parts:

    * the first ``z_pres_size`` features, squashed with sigmoid (``z_pres_p``);
    * the next ``z_where_size`` features, used as-is (``z_where_loc``);
    * the last ``z_where_size`` features, made positive with softplus
      (``z_where_scale``).

    Args:
        input_size: Width of the hidden-state input.
        h_sizes: Hidden-layer widths; any sequence of ints (list or tuple).
        z_pres_size: Number of z_pres outputs.
        z_where_size: Number of z_where location (and scale) outputs.
        non_linear_layer: Zero-argument callable returning an activation
            module, inserted between linear layers.
    """

    def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer):
        super().__init__()
        self.z_pres_size = z_pres_size
        self.z_where_size = z_where_size
        output_size = z_pres_size + 2 * z_where_size
        # list(...) so tuples (or other sequences) of hidden sizes also work.
        self.mlp = MLP(input_size, list(h_sizes) + [output_size], non_linear_layer)

    def forward(self, h):
        """Return ``(z_pres_p, z_where_loc, z_where_scale)``.

        ``h`` is presumably a ``(batch, input_size)`` tensor; the MLP output
        is sliced along dim 1.
        """
        out = self.mlp(h)
        # Column offset where the z_where scale features begin.
        scale_start = self.z_pres_size + self.z_where_size
        z_pres_p = torch.sigmoid(out[:, 0:self.z_pres_size])
        z_where_loc = out[:, self.z_pres_size:scale_start]
        z_where_scale = softplus(out[:, scale_start:])
        return z_pres_p, z_where_loc, z_where_scale


class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x):
        return x