in models/vrnn_hier.py [0:0]
def posterior(self, emb, use_mean=False, scale_var=1.):
    """Run every posterior branch and draw one latent sample per branch.

    For each branch in ``self.posterior_nets`` the matching embedding
    ``emb[ctx_idx]`` is split along dim 1:

    * the first ``self.n_ctx`` entries are the context, used to build the
      initial state of any ``ConvLSTM`` layer via
      ``self.posterior_init_nets[net_idx]``;
    * the remaining entries are the input that flows through the branch
      layers.

    The branch layer at index 3 (when it is not a ``ConvLSTM``) additionally
    receives the samples of all previous branches. Those samples are
    upsampled to the current branch's resolution and concatenated on dim 2.

    Args:
        emb: Indexable collection of embedding tensors, indexed by
            ``self.arch['latent']['ctx_idx'][net_idx]``.
        use_mean: Forwarded to ``flows.gaussian_rsample``. Presumably it
            makes the "sample" equal to the mean; confirm in ``flows``.
        scale_var: Multiplier applied to the raw (pre-softplus) variance
            activation before it is converted to a log-variance. Note that
            this is NOT a direct scaling of the variance itself.

    Returns:
        A list with one ``[mean, logvar, z0, z0, None]`` entry per posterior
        branch, in branch order. ``mean`` and ``logvar`` are the two halves
        of the branch output split on dim 2. ``z0`` is stored twice.
        NOTE(review): the two ``z0`` slots presumably reserve room for
        pre-/post-flow samples used by callers; confirm.
    """
    dists = []
    for net_idx, branch_layers in enumerate(self.posterior_nets):
        latent_arch = self.arch['latent']
        # Find the corresponding activations and split off the context part.
        ctx_idx = latent_arch['ctx_idx'][net_idx]
        source = emb[ctx_idx]
        cur_ctx = source[:, :self.n_ctx].contiguous()
        out = source[:, self.n_ctx:].contiguous()

        for layer_idx, layer in enumerate(branch_layers):
            if isinstance(layer, ConvLSTM):
                # Initial condition: fold every dim between batch and the last
                # two spatial dims into one, run it through the init net
                # (which expects an extra singleton dim 1), then drop that dim.
                cur_ctx = cur_ctx.view(
                    cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1]
                )
                cur_ctx = self.posterior_init_nets[net_idx](cur_ctx.unsqueeze(1))
                cur_ctx = cur_ctx.squeeze(1)
                # Two halves along dim 1 - presumably the (hidden, cell) state.
                out = layer(out, torch.chunk(cur_ctx, 2, 1))
            elif layer_idx == 3:
                # Dense connectivity between latents.
                # NOTE(review): relies on the dense layer sitting at index 3 of
                # every branch; confirm against the architecture builder.
                cur_res = latent_arch['resolution'][net_idx]
                upsampled = []
                for prev_idx in range(net_idx):
                    prev_res = latent_arch['resolution'][prev_idx]
                    # Integer factor: resolutions are presumably multiples of
                    # each other.
                    factor = cur_res // prev_res
                    # [-2] is a z0 slot of the entry appended below.
                    z_prev = dists[prev_idx][-2]
                    b, t, c, h, w = z_prev.shape
                    z_prev = F.interpolate(
                        z_prev.view(b * t, c, h, w), scale_factor=factor
                    )
                    upsampled.append(
                        z_prev.view(b, t, c, z_prev.shape[-2], z_prev.shape[-1])
                    )
                out = layer(torch.cat(upsampled + [out], 2))
            else:
                out = layer(out)

        # Distribution statistics: first half of dim 2 is the mean, second half
        # the raw variance activation.
        mean, var = torch.chunk(out, 2, 2)
        # Scale the pre-softplus activation (not the variance itself).
        var = var * scale_var
        # logvar = log(softplus(var)). For very negative inputs softplus
        # underflows to 0, which gives logvar = -inf and NaN gradients. Below
        # the threshold, use the asymptote log(softplus(x)) ~= x; its error,
        # ~exp(x)/2, is below float32 resolution there. Above the threshold
        # the result is identical to the plain formula. The clamp keeps the
        # unselected branch of torch.where finite, so its (zero) gradient
        # cannot become NaN.
        linear_below = -15.0
        logvar = torch.where(
            var < linear_below,
            var,
            F.softplus(var.clamp(min=linear_below)).log(),
        )
        # Generate a sample from this distribution.
        z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)
        dists.append([mean, logvar, z0, z0, None])
    return dists