in neural/model.py
import torch as th  # `center_trim` and the sub-modules are assumed defined elsewhere in the file


def forward(self, meg, forcings, subject_id):
    forcings = dict(forcings)
    batch, _, length = meg.size()
    inputs = []
    # Zero out masked MEG entries and feed the mask itself as an extra input.
    mask = self.get_meg_mask(meg, forcings)
    meg = meg * mask
    inputs += [meg, mask]
    # Optional subject conditioning: broadcast the embedding over time.
    if self.subject_embedding is not None:
        subject = self.subject_embedding(subject_id)
        inputs.append(subject.view(batch, -1, 1).expand(-1, -1, length))
    # Keep only the configured forcings, in deterministic (sorted-key) order.
    # Filtering after sorting also avoids the empty-unpack error the old
    # `_, forcings = zip(*sorted(...))` raised when nothing matched.
    if self.forcing_dims:
        forcings = [v for k, v in sorted(forcings.items())
                    if k in self.forcing_dims]
    else:
        forcings = []
    inputs.extend(forcings)
    # Concatenate all inputs along the channel dimension.
    x = th.cat(inputs, dim=1)
    x = self.pad(x)
    x = self.encoder(x)
    if self.lstm is not None:
        # nn.LSTM expects (time, batch, channels); permute there and back.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)
    out = self.decoder(x)
    # Trim the padded output back to the input's temporal length.
    return center_trim(out, length)
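

# A minimal usage sketch. The shapes, the forcing name "audio_embed", and the
# `model` instance are hypothetical placeholders, not taken from the repo:
# `forward` only requires meg of shape (batch, channels, time), a dict of
# forcing tensors, and one subject index per batch item.
meg = th.randn(8, 273, 1024)                        # (batch, sensors, time)
forcings = {"audio_embed": th.randn(8, 40, 1024)}   # same temporal length as meg
subject_id = th.arange(8)                           # one subject index per item
out = model(meg, forcings, subject_id)              # trimmed back to length 1024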