def seq2seq_att()

in src/flint/torch_util.py


import torch
import torch.nn.functional as F


def seq2seq_att(mems, lengths, state, att_net=None):
    """
    :param mems: [B, T, D_mem] This are the memories.
                    I call memory for this variable because I think attention is just like read something and then
                    make alignments with your memories.
                    This memory here is usually the input hidden state of the encoder.

    :param lengths: [B]
    :param state: [B, D_state]
                    I call state for this variable because it's the state I percepts at this time step.

    :param att_net: This is the attention network that will be used to calculate the alignment score between
                    state and memories.
                    input of the att_net is mems and state with shape:
                        mems: [exB, D_mem]
                        state: [exB, D_state]
                    return of the att_net is [exB, 1]

                    So any function that map a vector to a scalar could work.

    :return: [B, D_result]
    """

    d_state = state.size(1)

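    # Without an attention network there is nothing to align; return the state unchanged.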
    if not att_net:
        return state
    else:
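        # Pack the valid (unpadded) timesteps of every example into one flat
        # "expanded batch" so att_net can score all (memory, state) pairs in a
        # single forward pass.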
        batch_list_mems = []
        batch_list_state = []
        for i, l in enumerate(lengths):
            b_mems = mems[i, :l]  # [T, D_mem]
            batch_list_mems.append(b_mems)

            b_state = state[i].expand(b_mems.size(0), d_state)  # [T, D_state]
            batch_list_state.append(b_state)

        packed_sequence_mems = torch.cat(batch_list_mems, 0)  # [sum(l), D_mem]
        packed_sequence_state = torch.cat(batch_list_state, 0)  # [sum(l), D_state]

        align_score = att_net(packed_sequence_mems, packed_sequence_state)  # [sum(l), 1]

        # align_score is flat; conceptually it is grouped per example,
        # e.g. [(a1, a2, a3), (a1, a2), (a1, a2, a3, a4)] for lengths [3, 2, 4].

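        # Walk the flat tensors segment by segment: softmax each example's scores
        # over its own true length, then take the attention-weighted sum of its memories.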
        start = 0
        result_list = []
        for l in lengths:
            end = start + l

            b_mems = packed_sequence_mems[start:end, :]  # [l, D_mem]
            b_score = align_score[start:end, :]  # [l, 1]

            softed_b_score = F.softmax(b_score, dim=0)  # [l, 1]

            weighted_sum = torch.sum(b_mems * softed_b_score, dim=0, keepdim=False)  # [D_mem]

            result_list.append(weighted_sum)

            start = end

        result = torch.stack(result_list, dim=0)

        return result
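
Usage sketch (an illustration, not part of the original module): any callable
that maps a (memory, state) pair to a scalar score can serve as att_net. The
MlpScorer below and all dimensions are assumptions chosen for the example.

import torch
import torch.nn as nn


class MlpScorer(nn.Module):
    """A hypothetical scorer: concat the memory and state, score with a small MLP."""

    def __init__(self, d_mem, d_state, d_hidden=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_mem + d_state, d_hidden),
            nn.Tanh(),
            nn.Linear(d_hidden, 1),
        )

    def forward(self, mems, state):
        # mems: [exB, D_mem], state: [exB, D_state] -> [exB, 1]
        return self.layers(torch.cat([mems, state], dim=1))


# Dummy batch: B=3 examples padded to T=5 steps, D_mem=8, D_state=6.
mems = torch.randn(3, 5, 8)
lengths = torch.tensor([5, 3, 4])
state = torch.randn(3, 6)

context = seq2seq_att(mems, lengths, state, MlpScorer(d_mem=8, d_state=6))
print(context.shape)  # torch.Size([3, 8])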