in code/experiment_synthetic/sem.py [0:0]
def __call__(self, n, env):
    """Draw ``n`` samples from the linear chain model for environment ``env``.

    The generative chain is hidden -> x -> y -> z. Each node also receives
    fresh standard-normal noise of shape ``(n, self.dim)``. Noise is drawn in
    a fixed order (hidden, x, y, z), so a seeded RNG gives reproducible
    samples.

    ``env`` multiplies the hidden variable, x's noise, and exactly one of the
    y/z noise terms:
      * ``self.hetero`` truthy -> y's noise is scaled, z's is not;
      * otherwise              -> z's noise is scaled, y's is not.
    NOTE(review): ``env`` is presumably a scalar (float); confirm with callers.

    Args:
        n: number of rows to sample.
        env: multiplier applied to the environment-dependent noise terms.

    Returns:
        A ``(features, target)`` pair. ``features`` is ``cat((x, z), dim=1)``
        right-multiplied by ``self.scramble``. ``target`` is ``y`` summed over
        dim 1 with ``keepdim=True``, so it has shape ``(n, 1)``.
        Assumes the ``w*`` matrices are ``(self.dim, self.dim)``; confirm.
    """

    def draw():
        # One fresh block of standard-normal noise; reads self.dim each call.
        return torch.randn(n, self.dim)

    scale_y_noise = bool(self.hetero)

    hidden = draw() * env
    x = hidden @ self.whx + draw() * env

    noise_y = draw()
    if scale_y_noise:
        noise_y = noise_y * env
    y = x @ self.wxy + hidden @ self.why + noise_y

    # z's noise is drawn only after y's, preserving the RNG consumption order.
    noise_z = draw()
    if not scale_y_noise:
        noise_z = noise_z * env
    z = y @ self.wyz + hidden @ self.whz + noise_z

    features = torch.cat((x, z), 1) @ self.scramble
    target = y.sum(1, keepdim=True)
    return features, target