in sing/sequence/models.py [0:0]
def forward(self, start=0, length=None, hidden=None, **tensors):
    """
    Run the recurrent generator over the time steps
    ``[start, start + length)``.

    Arguments:
        start (int): first time step to generate
        length (int): length of the sequence to generate. If `None`,
            will be taken to be `self.length - start`
        hidden ((torch.FloatTensor, torch.FloatTensor)):
            hidden state of the LSTM or `None` to start
            from a blank one. It is expected in the layout returned by a
            previous call, i.e. with dimensions 0 and 1 swapped relative
            to what `self.lstm` consumes. The state returned by one call
            can therefore be fed back to continue generation from where
            it stopped.
        **tensors (dict[str, torch.LongTensor]):
            dictionary containing the tensors used as inputs
            to the lookup tables specified by the `embeddings`
            parameter of the constructor. One entry is required for
            every name in `self.inputs`.

    Returns:
        (torch.FloatTensor, list[torch.FloatTensor]): the decoded
        sequence, laid out as (batch, features, time), and the final
        LSTM state in the same layout as the `hidden` argument.

    Raises:
        ValueError: if the number of steps to generate is smaller than 1.
    """
    if length is None:
        length = self.length - start
    if length < 1:
        # `expand(-1, ...)` below would silently keep a size of 1 instead of
        # failing, so invalid lengths must be rejected up front.
        raise ValueError(
            "length must be at least 1, got {} (start={})".format(
                length, start))

    inputs = []
    for name, table in zip(self.inputs, self.tables):
        # NOTE(review): presumably each tensor is (batch, 1). After the
        # transpose, the embedding is (1, batch, dim) and is broadcast over
        # all `length` time steps.
        value = tensors[name].transpose(0, 1)
        inputs.append(table(value).expand(length, -1, -1))
    # Dimension 0 of every input is time, dimension 1 is batch.
    reference = inputs[0]
    if self.time_table is not None:
        # Absolute time indices, shape (length, batch). A generation split
        # into several calls therefore sees the same positions as one long
        # call.
        times = torch.arange(
            start, start + length,
            device=reference.device).view(-1, 1).expand(
                -1, reference.size(1))
        inputs.append(self.time_table(times))
    lstm_input = torch.cat(inputs, dim=-1)

    if hidden is not None:
        # nn.LSTM documents its initial state as a tuple (h_0, c_0).
        hidden = tuple(h.transpose(0, 1).contiguous() for h in hidden)
    self.lstm.flatten_parameters()
    output, hidden = self.lstm(lstm_input, hidden)
    # Decode every (time, batch) position at once, then restore the layout.
    decoded = self.decoder(output.view(-1, output.size(-1))).view(
        output.size(0), output.size(1), -1)
    hidden = [h.transpose(0, 1) for h in hidden]
    return decoded.transpose(0, 1).transpose(1, 2), hidden