in src/flint/torch_util.py [0:0]
def seq2seq_att(mems, lengths, state, att_net=None):
    """
    Attend over a padded batch of memories using the current state.

    :param mems: [B, T, D_mem] These are the memories.
        I call this variable "memory" because I think attention is just like reading something and
        then aligning it with your memories.
        This memory is usually the hidden state sequence produced by the encoder.
    :param lengths: [B] valid length of each sequence in ``mems`` (list of ints or an int tensor).
        Lengths larger than T are clamped to T.
    :param state: [B, D_state]
        I call this variable "state" because it is the state perceived at this time step.
    :param att_net: The attention network used to calculate the alignment score between
        state and memories.
        Its inputs are mems and state with shapes:
        mems: [exB, D_mem]
        state: [exB, D_state]
        and it must return [exB, 1].
        So any function that maps a vector pair to a scalar could work.
        If ``None``, ``state`` is returned unchanged.
    :return: [B, D_mem] attention-weighted sum of the valid memories of each example
        (or ``state`` itself, [B, D_state], when ``att_net`` is ``None``).
    """
    # Explicit None check: a module that defines __len__ (e.g. an empty
    # nn.Sequential) would be falsy under `not att_net`.
    if att_net is None:
        return state

    d_state = state.size(1)
    # Normalize to Python ints so list and tensor `lengths` behave the same.
    lengths = [int(l) for l in lengths]

    batch_list_mems = []
    batch_list_state = []
    for i, l in enumerate(lengths):
        b_mems = mems[i, :l]  # [l', D_mem]; slicing clamps l' = min(l, T)
        batch_list_mems.append(b_mems)
        b_state = state[i].expand(b_mems.size(0), d_state)  # [l', D_state]
        batch_list_state.append(b_state)

    packed_sequence_mems = torch.cat(batch_list_mems, 0)  # [sum(l'), D_mem]
    packed_sequence_state = torch.cat(batch_list_state, 0)  # [sum(l'), D_state]
    # A single att_net call scores every valid (memory, state) pair at once.
    align_score = att_net(packed_sequence_mems, packed_sequence_state)  # [sum(l'), 1]
    # Scores are grouped as [(a1, a2, a3), (a1, a2), (a1, a2, a3, a4)].

    start = 0
    result_list = []
    for b_mems in batch_list_mems:
        # Offsets come from the actual (clamped) chunk sizes used to build the
        # packed tensors, so scores can never drift onto another example.
        end = start + b_mems.size(0)
        b_score = align_score[start:end]  # [l', 1]
        softed_b_score = F.softmax(b_score, dim=0)  # normalize over the sequence
        weighted_sum = torch.sum(b_mems * softed_b_score, dim=0, keepdim=False)  # [D_mem]
        result_list.append(weighted_sum)
        start = end

    return torch.stack(result_list, dim=0)  # [B, D_mem]