in code/experiment_synthetic/sem.py [0:0]
def __init__(self, dim, ones=True, scramble=False, hetero=True, hidden=False):
    """Build the weight matrices of a chain of linear equations.

    The attribute names suggest a chain x -> y -> z with an optional hidden
    variable h feeding each link (NOTE(review): inferred from names only;
    confirm against the code that consumes these matrices).

    Args:
        dim: Total dimensionality; must be a positive even integer. Every
            per-variable weight matrix is ``dim // 2`` square, while
            ``self.scramble`` is ``dim x dim``.
        ones: If True, ``wxy`` and ``wyz`` are identity matrices; otherwise
            they are standard-normal draws divided by ``dim``.
        scramble: If True, ``self.scramble`` is the orthogonal Q factor of
            the QR decomposition of a standard-normal ``dim x dim`` matrix;
            otherwise it is the ``dim x dim`` identity.
        hetero: Stored as ``self.hetero``; not read by this method.
        hidden: If True, ``whx``, ``why`` and ``whz`` are standard-normal
            draws divided by ``dim``; otherwise ``whx`` is the identity and
            ``why``/``whz`` are all-zero.

    Raises:
        ValueError: If ``dim`` is not a positive even integer. For odd
            ``dim`` the two ``dim // 2`` halves cannot span the
            ``dim x dim`` scramble matrix.
    """
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a positive even integer, got {dim!r}")

    self.hetero = hetero
    self.hidden = hidden
    self.dim = dim // 2

    # Random draws below keep their original order (wxy, wyz, scramble,
    # whx, why, whz) so results under a fixed torch seed are unchanged.
    # NOTE(review): random weights are scaled by 1 / dim (full dim), not
    # 1 / self.dim; kept as-is to preserve existing experiments.
    if ones:
        self.wxy = torch.eye(self.dim)
        self.wyz = torch.eye(self.dim)
    else:
        self.wxy = torch.randn(self.dim, self.dim) / dim
        self.wyz = torch.randn(self.dim, self.dim) / dim

    if scramble:
        # torch.qr is deprecated in favour of torch.linalg.qr (whose default
        # mode="reduced" matches torch.qr(some=True)); fall back to the
        # legacy function on PyTorch versions without torch.linalg.qr.
        qr = getattr(getattr(torch, "linalg", None), "qr", None) or torch.qr
        self.scramble, _ = qr(torch.randn(dim, dim))
    else:
        self.scramble = torch.eye(dim)

    if hidden:
        self.whx = torch.randn(self.dim, self.dim) / dim
        self.why = torch.randn(self.dim, self.dim) / dim
        self.whz = torch.randn(self.dim, self.dim) / dim
    else:
        self.whx = torch.eye(self.dim)
        self.why = torch.zeros(self.dim, self.dim)
        self.whz = torch.zeros(self.dim, self.dim)