in models/s2s_big_hier_v4.py [0:0]
def prior(self, x, ctx, q_dists, use_prior, use_mean=False, scale_var=1.):
    """Compute the prior distribution of every stochastic branch, top-down.

    Branches are visited in descending ``layer_idx`` order
    (``sorted(self.sto_branches, reverse=True)``), so ``dists[0]`` belongs to
    the highest stochastic layer, ``dists[1]`` to the next one, and so on.
    Lower branches may be conditioned on samples of already-visited branches
    (see ``z_rules`` below).

    Args:
        x: indexable by layer index; ``x[layer_idx]`` is sliced along dim 1
            as ``[:, n_ctx - 1:-1]``. Presumably (batch, time, C, H, W)
            activations -- TODO confirm.
        ctx: indexable by layer index; the first ``n_ctx`` entries along
            dim 1 of ``ctx[layer_idx]`` initialise any ConvLSTM layer.
        q_dists: posterior distributions, assumed to use the same list
            layout as the returned ``dists``. Their ``[-2]`` entry is used
            for conditioning when ``use_prior`` is False.
        use_prior: if True, condition on samples from this prior (``dists``);
            otherwise on samples from ``q_dists``.
        use_mean: forwarded to ``flows.gaussian_rsample``.
        scale_var: multiplier applied to the softplus variance before the log
            (it scales the variance, not the standard deviation).

    Returns:
        A list with one ``[mean, logvar, z, z, None]`` entry per branch, in
        visiting order. The sample ``z`` appears twice.
    """
    # Position inside a branch at which samples of higher branches are merged.
    merge_position = 3
    # Hand-crafted top-down conditioning: layer_idx -> ((dist index, upsample
    # factor), ...). A dist index refers to visiting order; with the expected
    # branches {16, 10, 4}, 0 is layer 16 and 1 is layer 10. Layers without an
    # entry (e.g. the top layer 16) apply their merge layer unconditioned.
    z_rules = {
        10: ((0, 8),),
        4: ((0, 32), (1, 4)),
    }

    def _upsample(z, factor):
        """Spatially upsample a 5-D (b, t, c, h, w) tensor frame by frame."""
        b, t, c, h, w = z.shape
        # Fold time into batch so F.interpolate sees a 4-D (N, C, H, W) input.
        z = F.interpolate(z.view(b * t, c, h, w), scale_factor=factor)
        return z.view(b, t, c, z.shape[-2], z.shape[-1])

    dists = []
    for layer_idx in sorted(self.sto_branches.keys(), reverse=True):
        key = 'layer_{}'.format(layer_idx)
        # Find the corresponding activations and context frames.
        out = x[layer_idx][:, self.n_ctx - 1: -1].contiguous()
        cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()

        for branch_layer_idx, layer in enumerate(self.prior_branches[key]):
            if isinstance(layer, layers.ConvLSTM):
                # Build the LSTM initial state from the context. Fold every
                # dim between batch and (H, W) into channels, then run the
                # per-branch init net with a singleton dim 1. Finally, split
                # the result in two along dim 1 (presumably the (h, c) pair
                # -- TODO confirm).
                # NOTE(review): a second ConvLSTM in the same branch would
                # re-run the init net on the already-processed cur_ctx.
                cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1,
                                       cur_ctx.shape[-2], cur_ctx.shape[-1])
                cur_ctx = self.prior_init_nets[key](cur_ctx.unsqueeze(1))
                cur_ctx = cur_ctx.squeeze(1)
                out = layer(out, torch.chunk(cur_ctx, 2, 1))
            elif branch_layer_idx == merge_position and layer_idx in z_rules:
                # Concatenate upsampled samples of higher branches on dim 2.
                source = dists if use_prior else q_dists
                extra = [_upsample(source[i][-2], factor)
                         for i, factor in z_rules[layer_idx]]
                out = layer(torch.cat([out] + extra, 2))
            else:
                # Previously, position 3 of a layer without a rule was
                # silently skipped; it is now applied like any other layer.
                out = layer(out)

        # Split the branch output into mean and raw variance along dim 2.
        mean, var = torch.chunk(out, 2, 2)
        # Variance is parameterised through softplus and scaled by scale_var.
        # NOTE(review): logvar is -inf if scale_var == 0 or softplus underflows.
        scaled_var = F.softplus(var) * scale_var
        logvar = scaled_var.log()
        z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)
        dists.append([mean, logvar, z0, z0, None])
    return dists