# examples/air/air.py
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
An implementation of the model described in [1].
[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from modules import MLP, Decoder, Encoder, Identity, Predict
# Default prior success probability for z_pres.
def default_z_pres_prior_p(t):
return 0.5
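# An alternative, time-decaying schedule can be passed through `model`'s
# **kwargs, e.g. (a hypothetical annealing schedule):
#     z_pres_prior_p=lambda t: 0.99 ** t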
ModelState = namedtuple('ModelState', ['x', 'z_pres', 'z_where'])
GuideState = namedtuple('GuideState', ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what'])
class AIR(nn.Module):
def __init__(self,
num_steps,
x_size,
window_size,
z_what_size,
rnn_hidden_size,
encoder_net=[],
decoder_net=[],
predict_net=[],
embed_net=None,
bl_predict_net=[],
non_linearity='ReLU',
decoder_output_bias=None,
decoder_output_use_sigmoid=False,
use_masking=True,
use_baselines=True,
baseline_scalar=None,
scale_prior_mean=3.0,
scale_prior_sd=0.1,
pos_prior_mean=0.0,
pos_prior_sd=1.0,
likelihood_sd=0.3,
use_cuda=False):
super().__init__()
self.num_steps = num_steps
self.x_size = x_size
self.window_size = window_size
self.z_what_size = z_what_size
self.rnn_hidden_size = rnn_hidden_size
self.use_masking = use_masking
self.use_baselines = use_baselines
self.baseline_scalar = baseline_scalar
self.likelihood_sd = likelihood_sd
self.use_cuda = use_cuda
prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.)
self.options = dict(dtype=prototype.dtype, device=prototype.device)
self.z_pres_size = 1
self.z_where_size = 3
# Registering these as nn.Parameters means they will be moved to
# the GPU when necessary. (They are not registered with Pyro for
# optimization.)
self.z_where_loc_prior = nn.Parameter(
torch.tensor([scale_prior_mean, pos_prior_mean, pos_prior_mean], dtype=torch.float),
requires_grad=False)
self.z_where_scale_prior = nn.Parameter(
torch.tensor([scale_prior_sd, pos_prior_sd, pos_prior_sd], dtype=torch.float),
requires_grad=False)
# Create nn modules.
rnn_input_size = x_size ** 2 if embed_net is None else embed_net[-1]
rnn_input_size += self.z_where_size + z_what_size + self.z_pres_size
nl = getattr(nn, non_linearity)
self.rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size)
self.encode = Encoder(window_size ** 2, encoder_net, z_what_size, nl)
self.decode = Decoder(window_size ** 2, decoder_net, z_what_size,
decoder_output_bias, decoder_output_use_sigmoid, nl)
self.predict = Predict(rnn_hidden_size, predict_net, self.z_pres_size, self.z_where_size, nl)
self.embed = Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True)
self.bl_rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size)
self.bl_predict = MLP(rnn_hidden_size, bl_predict_net + [1], nl)
self.bl_embed = Identity() if embed_net is None else MLP(x_size ** 2, embed_net, nl, True)
# Create parameters.
self.h_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
self.c_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
self.bl_h_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
self.bl_c_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
self.z_where_init = nn.Parameter(torch.zeros(1, self.z_where_size))
self.z_what_init = nn.Parameter(torch.zeros(1, self.z_what_size))
if use_cuda:
self.cuda()
def prior(self, n, **kwargs):
state = ModelState(
x=torch.zeros(n, self.x_size, self.x_size, **self.options),
z_pres=torch.ones(n, self.z_pres_size, **self.options),
z_where=None)
z_pres = []
z_where = []
for t in range(self.num_steps):
state = self.prior_step(t, n, state, **kwargs)
z_where.append(state.z_where)
z_pres.append(state.z_pres)
return (z_where, z_pres), state.x
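# For example (a sketch; `air` is an instance of this class):
#     (z_where, z_pres), x = air.prior(5)
# yields two length-`num_steps` lists of per-step latents and a batch
# of five sampled images of shape (5, x_size, x_size).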
def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):
# Sample presence indicators.
z_pres = pyro.sample('z_pres_{}'.format(t),
dist.Bernoulli(z_pres_prior_p(t) * prev.z_pres)
.to_event(1))
# If zero is sampled for a data point, then no more objects
# will be added to its output image. We can't
# straightforwardly avoid generating further objects, so
# instead we zero out the log_prob_sum of future choices.
sample_mask = z_pres if self.use_masking else torch.tensor(1.0)
# Sample attention window position.
z_where = pyro.sample('z_where_{}'.format(t),
dist.Normal(self.z_where_loc_prior.expand(n, self.z_where_size),
self.z_where_scale_prior.expand(n, self.z_where_size))
.mask(sample_mask)
.to_event(1))
# Sample latent code for contents of the attention window.
z_what = pyro.sample('z_what_{}'.format(t),
dist.Normal(torch.zeros(n, self.z_what_size, **self.options),
torch.ones(n, self.z_what_size, **self.options))
.mask(sample_mask)
.to_event(1))
# Map latent code to pixel space.
y_att = self.decode(z_what)
# Position/scale attention window within larger image.
y = window_to_image(z_where, self.window_size, self.x_size, y_att)
# Combine the image generated at this step with the image so far.
# (Note that there's no notion of occlusion here. Overlapping
# objects can create pixel intensities > 1.)
x = prev.x + (y * z_pres.view(-1, 1, 1))
return ModelState(x=x, z_pres=z_pres, z_where=z_where)
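# Shapes through one generative step:
#     z_what: (n, z_what_size) --decode--> y_att: (n, window_size ** 2)
#     --window_to_image--> y: (n, x_size, x_size), composited into x.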
def model(self, data, batch_size, **kwargs):
pyro.module("decode", self.decode)
with pyro.plate('data', data.size(0), device=data.device) as ix:
batch = data[ix]
n = batch.size(0)
(z_where, z_pres), x = self.prior(n, **kwargs)
pyro.sample('obs',
dist.Normal(x.view(n, -1),
(self.likelihood_sd * torch.ones(n, self.x_size ** 2, **self.options)))
.to_event(1),
obs=batch.view(n, -1))
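# Note: when run under SVI, the model's 'data' plate reuses the
# subsample indices drawn by the guide's plate of the same name, so
# `ix` above selects the same mini-batch in model and guide.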
def guide(self, data, batch_size, **kwargs):
pyro.module('rnn', self.rnn)
pyro.module('predict', self.predict)
pyro.module('encode', self.encode)
pyro.module('embed', self.embed)
pyro.module('bl_rnn', self.bl_rnn)
pyro.module('bl_predict', self.bl_predict)
pyro.module('bl_embed', self.bl_embed)
pyro.param('h_init', self.h_init)
pyro.param('c_init', self.c_init)
pyro.param('z_where_init', self.z_where_init)
pyro.param('z_what_init', self.z_what_init)
pyro.param('bl_h_init', self.bl_h_init)
pyro.param('bl_c_init', self.bl_c_init)
with pyro.plate('data', data.size(0), subsample_size=batch_size, device=data.device) as ix:
batch = data[ix]
n = batch.size(0)
# Embed inputs.
flattened_batch = batch.view(n, -1)
inputs = {
'raw': batch,
'embed': self.embed(flattened_batch),
'bl_embed': self.bl_embed(flattened_batch)
}
# Initial state.
state = GuideState(
h=batch_expand(self.h_init, n),
c=batch_expand(self.c_init, n),
bl_h=batch_expand(self.bl_h_init, n),
bl_c=batch_expand(self.bl_c_init, n),
z_pres=torch.ones(n, self.z_pres_size, **self.options),
z_where=batch_expand(self.z_where_init, n),
z_what=batch_expand(self.z_what_init, n))
z_pres = []
z_where = []
for t in range(self.num_steps):
state = self.guide_step(t, n, state, inputs)
z_where.append(state.z_where)
z_pres.append(state.z_pres)
return z_where, z_pres
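# For example (a sketch):
#     z_where, z_pres = air.guide(data, batch_size=64)
# returns two length-`num_steps` lists; `latents_to_tensor` (defined
# below) combines them into a single (batch, num_steps, 4) tensor.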
def guide_step(self, t, n, prev, inputs):
rnn_input = torch.cat((inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1)
h, c = self.rnn(rnn_input, (prev.h, prev.c))
z_pres_p, z_where_loc, z_where_scale = self.predict(h)
# Compute baseline estimates for discrete choice z_pres.
infer_dict, bl_h, bl_c = self.baseline_step(prev, inputs)
# Sample presence.
z_pres = pyro.sample('z_pres_{}'.format(t),
dist.Bernoulli(z_pres_p * prev.z_pres).to_event(1),
infer=infer_dict)
sample_mask = z_pres if self.use_masking else torch.tensor(1.0)
z_where = pyro.sample('z_where_{}'.format(t),
dist.Normal(z_where_loc + self.z_where_loc_prior,
z_where_scale * self.z_where_scale_prior)
.mask(sample_mask)
.to_event(1))
# Figure 2 of [1] shows x_att depending on z_where and h,
# rather than z_where and x as here, but I think this is
# correct.
x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw'])
# Encode attention windows.
z_what_loc, z_what_scale = self.encode(x_att)
z_what = pyro.sample('z_what_{}'.format(t),
dist.Normal(z_what_loc, z_what_scale)
.mask(sample_mask)
.to_event(1))
return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what)
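# Shapes through one inference step:
#     rnn_input: (n, rnn_input_size) --rnn--> h: (n, rnn_hidden_size)
#     x_att: (n, window_size ** 2) --encode--> z_what_loc, z_what_scale: (n, z_what_size)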
def baseline_step(self, prev, inputs):
if not self.use_baselines:
return dict(), None, None
# Prevent gradients flowing back from baseline loss to
# inference net by detaching from graph here.
rnn_input = torch.cat((inputs['bl_embed'],
prev.z_where.detach(),
prev.z_what.detach(),
prev.z_pres.detach()), 1)
bl_h, bl_c = self.bl_rnn(rnn_input, (prev.bl_h, prev.bl_c))
bl_value = self.bl_predict(bl_h)
# Zero out values for finished data points. This avoids adding
# superfluous terms to the loss.
if self.use_masking:
bl_value = bl_value * prev.z_pres
# The value that the baseline net is estimating can be very
# large. An option to scale the net's output is provided
# to make it easier for the net to output values of this
# scale.
if self.baseline_scalar is not None:
bl_value = bl_value * self.baseline_scalar
infer_dict = dict(baseline=dict(baseline_value=bl_value.squeeze(-1)))
return infer_dict, bl_h, bl_c
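# A minimal training sketch (illustrative only; the full experiment is
# driven by main.py in this directory, and the hyperparameters below
# are placeholders):
#
#     air = AIR(num_steps=3, x_size=50, window_size=28, z_what_size=50,
#               rnn_hidden_size=256, encoder_net=[200], decoder_net=[200],
#               predict_net=[200], bl_predict_net=[200])
#     svi = pyro.infer.SVI(air.model, air.guide,
#                          pyro.optim.Adam({'lr': 1e-4}),
#                          loss=pyro.infer.TraceGraph_ELBO())
#     data = torch.rand(256, 50, 50)  # stand-in for multi-MNIST images
#     loss = svi.step(data, batch_size=64)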
# Spatial transformer helpers.
expansion_indices = torch.tensor([1, 0, 2, 0, 1, 3])
def expand_z_where(z_where):
# Takes a batch of three-vectors and massages them into a batch of
# 2x3 matrices with elements like so:
# [s,x,y] -> [[s,0,x],
# [0,s,y]]
n = z_where.size(0)
out = torch.cat((z_where.new_zeros(n, 1), z_where), 1)
ix = expansion_indices.to(z_where.device)
out = torch.index_select(out, 1, ix)
out = out.view(n, 2, 3)
return out
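# Worked example (doctest-style sketch):
#     >>> expand_z_where(torch.tensor([[2.0, 0.5, -0.5]]))
#     tensor([[[ 2.0000,  0.0000,  0.5000],
#              [ 0.0000,  2.0000, -0.5000]]])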
# Scaling by `1/scale` here is unsatisfactory, as `scale` could be
# zero.
def z_where_inv(z_where):
# Take a batch of z_where vectors, and compute their "inverse".
# That is, for each row compute:
# [s,x,y] -> [1/s,-x/s,-y/s]
# These are the parameters required to perform the inverse of the
# spatial transform performed in the generative model.
n = z_where.size(0)
out = torch.cat((z_where.new_ones(n, 1), -z_where[:, 1:]), 1)
# Divide all entries by the scale.
out = out / z_where[:, 0:1]
return out
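# Worked example (doctest-style sketch):
#     >>> z_where_inv(torch.tensor([[2.0, 0.5, -0.5]]))
#     tensor([[ 0.5000, -0.2500,  0.2500]])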
def window_to_image(z_where, window_size, image_size, windows):
n = windows.size(0)
assert windows.size(1) == window_size ** 2, 'Size mismatch.'
theta = expand_z_where(z_where)
grid = F.affine_grid(theta, torch.Size((n, 1, image_size, image_size)))
out = F.grid_sample(windows.view(n, 1, window_size, window_size), grid)
return out.view(n, image_size, image_size)
def image_to_window(z_where, window_size, image_size, images):
n = images.size(0)
assert images.size(1) == images.size(2) == image_size, 'Size mismatch.'
theta_inv = expand_z_where(z_where_inv(z_where))
grid = F.affine_grid(theta_inv, torch.Size((n, 1, window_size, window_size)))
out = F.grid_sample(images.view(n, 1, image_size, image_size), grid)
return out.view(n, -1)
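# Since `image_to_window` applies the inverse of the affine transform
# used by `window_to_image`, cropping a window out of an image and
# pasting it back is, up to interpolation error, the identity on the
# window's support:
#     win = image_to_window(z_where, window_size, x_size, imgs)
#     back = window_to_image(z_where, window_size, x_size, win)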
# Helper to expand parameters to the size of the mini-batch. I would
# like to remove this and just write `t.expand(n, -1)` inline, but the
# `-1` argument of `expand` doesn't seem to work with PyTorch 0.2.0.
def batch_expand(t, n):
return t.expand(n, t.size(1))
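# e.g. batch_expand(torch.zeros(1, 4), 3) has shape (3, 4).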
# Combine z_pres and z_where (as returned by the model and guide) into
# a single tensor, with size:
# [batch_size, num_steps, z_where_size + z_pres_size]
def latents_to_tensor(z):
return torch.stack([
torch.cat((z_where.cpu().data, z_pres.cpu().data), 1)
for z_where, z_pres in zip(*z)]).transpose(0, 1)
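# With this file's sizes (z_where_size=3, z_pres_size=1) the result has
# shape (batch_size, num_steps, 4), e.g.:
#     latents_to_tensor(air.guide(data, batch_size=64)).shape
#     # => torch.Size([64, num_steps, 4])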