rat-sql-gap/seq2struct/models/spider/spider_enc.py
def forward(self, descs):
    # Encode the question
    # - LookupEmbeddings
    # - Transform embeddings wrt each other?
    # q_enc: PackedSequencePlus, [batch, question len, recurrent_size]
    qs = [[desc['question']] for desc in descs]
    q_enc, _ = self.question_encoder(qs)
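    # Each desc['question'] is a list of tokens; wrapping it in a singleton
    # list gives every batch item the same list-of-token-sequences shape that
    # the column and table encoders receive below.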
    # Encode the columns
    # - LookupEmbeddings
    # - Transform embeddings wrt each other?
    # - Summarize each column into one?
    # c_enc: PackedSequencePlus, [batch, sum of column lens, recurrent_size]
    c_enc, c_boundaries = self.column_encoder([desc['columns'] for desc in descs])
    column_pointer_maps = [
        {
            i: list(range(left, right))
            for i, (left, right) in enumerate(zip(c_boundaries_for_item, c_boundaries_for_item[1:]))
        }
        for batch_idx, c_boundaries_for_item in enumerate(c_boundaries)
    ]
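    # Worked example (hypothetical values): c_boundaries_for_item == [0, 2, 4]
    # yields {0: [0, 1], 1: [2, 3]}, i.e. each column index maps to the span of
    # its encoded tokens inside c_enc.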
    # Encode the tables
    # - LookupEmbeddings
    # - Transform embeddings wrt each other?
    # - Summarize each table into one?
    # t_enc: PackedSequencePlus, [batch, sum of table lens, recurrent_size]
    t_enc, t_boundaries = self.table_encoder([desc['tables'] for desc in descs])
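    # t_boundaries mirrors c_boundaries: consecutive entries delimit the token
    # span of each table inside t_enc for the corresponding batch item.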
    # c_enc_lengths = list(c_enc.orig_lengths())
    # table_pointer_maps = [
    #     {
    #         i: [
    #             idx
    #             for col in desc['table_to_columns'][str(i)]
    #             for idx in column_pointer_maps[batch_idx][col]
    #         ] + list(range(left + c_enc_lengths[batch_idx], right + c_enc_lengths[batch_idx]))
    #         for i, (left, right) in enumerate(zip(t_boundaries_for_item, t_boundaries_for_item[1:]))
    #     }
    #     for batch_idx, (desc, t_boundaries_for_item) in enumerate(zip(descs, t_boundaries))
    # ]
    # Directly point to the table
    table_pointer_maps = [
        {
            i: list(range(left, right))
            for i, (left, right) in enumerate(zip(t_boundaries_for_item, t_boundaries_for_item[1:]))
        }
        for batch_idx, (desc, t_boundaries_for_item) in enumerate(zip(descs, t_boundaries))
    ]
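    # Unlike the commented-out variant above, each table here points only at
    # its own span in t_enc, not additionally at its columns' positions.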
    # Update each other using self-attention
    # q_enc_new, c_enc_new, and t_enc_new are PackedSequencePlus with shape
    # batch (=1) x length x recurrent_size
    if self.batch_encs_update:
        q_enc_new, c_enc_new, t_enc_new = self.encs_update(
            descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries)
    result = []
    for batch_idx, desc in enumerate(descs):
        if self.batch_encs_update:
            q_enc_new_item = q_enc_new.select(batch_idx).unsqueeze(0)
            c_enc_new_item = c_enc_new.select(batch_idx).unsqueeze(0)
            t_enc_new_item = t_enc_new.select(batch_idx).unsqueeze(0)
            # The batched updater above does not return alignment matrices, so
            # give align_mat_item a default here; otherwise it would be
            # undefined when it is unpacked into the encoder state below.
            align_mat_item = (None, None)
        else:
            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.select(batch_idx).unsqueeze(1),
                    c_enc.select(batch_idx).unsqueeze(1),
                    c_boundaries[batch_idx],
                    t_enc.select(batch_idx).unsqueeze(1),
                    t_boundaries[batch_idx])
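        # align_mat_item is a pair of alignment matrices (memory-to-column,
        # memory-to-table); it is unpacked into m2c_align_mat and
        # m2t_align_mat on the encoder state below.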
        memory = []
        words_for_copying = []
        if 'question' in self.include_in_memory:
            memory.append(q_enc_new_item)
            if 'question_for_copying' in desc:
                assert q_enc_new_item.shape[1] == len(desc['question_for_copying'])
                words_for_copying += desc['question_for_copying']
            else:
                words_for_copying += [''] * q_enc_new_item.shape[1]
        if 'column' in self.include_in_memory:
            memory.append(c_enc_new_item)
            words_for_copying += [''] * c_enc_new_item.shape[1]
        if 'table' in self.include_in_memory:
            memory.append(t_enc_new_item)
            words_for_copying += [''] * t_enc_new_item.shape[1]
        memory = torch.cat(memory, dim=1)
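        # memory now has shape [1, total length, recurrent_size], where total
        # length sums the question/column/table spans selected by
        # include_in_memory; words_for_copying stays aligned with it, using ''
        # placeholders for non-question positions.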
        result.append(SpiderEncoderState(
            state=None,
            memory=memory,
            question_memory=q_enc_new_item,
            schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
            # TODO: words should match memory
            words=words_for_copying,
            pointer_memories={
                'column': c_enc_new_item,
                'table': torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
            },
            pointer_maps={
                'column': column_pointer_maps[batch_idx],
                'table': table_pointer_maps[batch_idx],
            },
            m2c_align_mat=align_mat_item[0],
            m2t_align_mat=align_mat_item[1],
        ))
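        # Note: pointer_memories['table'] concatenates column and table
        # encodings, while the active table_pointer_maps indexes table spans
        # from offset 0; the commented-out variant above instead shifted table
        # spans by the column-encoding length to line up with that
        # concatenation.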
    return result
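
# Usage sketch (assumptions: `encoder` is an instance of the enclosing encoder
# class, and `descs` is a list of preprocessed examples carrying 'question',
# 'columns', 'tables', and optionally 'question_for_copying'; the names here
# are illustrative, not part of the repo):
#
#     enc_states = encoder.forward(descs)
#     first = enc_states[0]
#     first.memory          # [1, total encoded length, recurrent_size]
#     first.pointer_maps    # {'column': {0: [...], ...}, 'table': {...}}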