in vae.py [0:0]
def get_inputs(self, xs, activations):
    """Return the input tensor and the activations stored under ``self.base``.

    Reads ``activations[self.base]`` and ``xs[self.base]``. If ``xs`` has no
    entry for ``self.base``, a zero tensor with the same shape/dtype/device as
    the activations is used instead. If the input's leading dimension is 1
    while the activations' leading dimension differs, the input is repeated
    along dim 0 so the two leading sizes match (all other dims untouched).

    NOTE(review): the leading dimension is presumably the batch dimension —
    confirm against callers.

    Args:
        xs: Mapping of key -> input tensor; ``self.base`` may be missing.
        activations: Mapping of key -> activation tensor; must contain
            ``self.base``.

    Returns:
        Tuple ``(x, acts)`` where ``x.shape[0] == acts.shape[0]``.

    Raises:
        KeyError: If ``self.base`` is not present in ``activations``.
        ValueError: If the leading dimensions differ and the input's leading
            dimension is not 1, so it cannot be repeated to match.
    """
    acts = activations[self.base]
    try:
        x = xs[self.base]
    except KeyError:
        # No explicit input for this key: fall back to zeros shaped like acts.
        x = torch.zeros_like(acts)

    target = acts.shape[0]
    if x.shape[0] != target:
        if x.shape[0] != 1:
            # Repeating a size-k leading dim by `target` would yield k*target
            # rows, silently mismatching acts; refuse instead.
            raise ValueError(
                f"cannot broadcast input with leading dim {x.shape[0]} "
                f"to activations with leading dim {target}"
            )
        # Repeat only along dim 0; the trailing ones match x's rank so this
        # works for any tensor rank (a fixed (n, 1, 1, 1) only fits 4-D).
        x = x.repeat(target, *([1] * (x.dim() - 1)))
    return x, acts