def extract_selfattention_maps()

in codes/attention.py


def extract_selfattention_maps(transformer_encoder, x):
    """Collect the self-attention map of every layer of a PyTorch nn.TransformerEncoder."""
    attn_logits_maps = []  # would hold the pre-softmax scores if compute_selfattention also returned them
    attn_probs_maps = []
    num_layers = transformer_encoder.num_layers
    d_model = transformer_encoder.layers[0].self_attn.embed_dim
    num_heads = transformer_encoder.layers[0].self_attn.num_heads
    norm_first = transformer_encoder.layers[0].norm_first
    for i in range(num_layers):
        # Attention is computed on the layer input; with pre-norm layers the
        # input is normalised before it reaches self-attention.
        h = x.clone()
        if norm_first:
            h = transformer_encoder.layers[i].norm1(h)
        attn_probs = compute_selfattention(transformer_encoder, h, i, d_model, num_heads)
        attn_probs_maps.append(attn_probs)
        # Run the full layer forward so the next layer sees its proper input.
        x = transformer_encoder.layers[i](x)

    return attn_probs_maps
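

The function relies on a helper, compute_selfattention, that is defined elsewhere in codes/attention.py and is not shown in this section. A minimal sketch of what such a helper could look like, assuming batch-first inputs of shape (batch, seq_len, d_model), the default packed in_proj_weight layout of nn.MultiheadAttention, and that only the attention probabilities are returned:


import math

import torch
import torch.nn.functional as F


def compute_selfattention(transformer_encoder, h, i, d_model, num_heads):
    # Sketch only: project h with the packed in-projection weights of layer i
    # to obtain the query and key tensors (the value projection is not needed
    # for the attention map itself).
    self_attn = transformer_encoder.layers[i].self_attn
    w, b = self_attn.in_proj_weight, self_attn.in_proj_bias  # (3*d_model, d_model), (3*d_model,)
    q = F.linear(h, w[:d_model], b[:d_model])
    k = F.linear(h, w[d_model:2 * d_model], b[d_model:2 * d_model])
    # Split into heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim).
    batch_size, seq_len, _ = h.shape
    head_dim = d_model // num_heads
    q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    # Scaled dot-product attention scores and probabilities.
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    attn_probs = attn_logits.softmax(dim=-1)  # (batch, num_heads, seq_len, seq_len)
    return attn_probs


A hypothetical usage example, assuming an encoder built with batch_first=True; the model size and shapes below are illustrative only:


import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=3)
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)

with torch.no_grad():
    attn_probs_maps = extract_selfattention_maps(encoder, x)

print(len(attn_probs_maps))      # one map per layer: 3
print(attn_probs_maps[0].shape)  # with the sketch above: (2, 4, 10, 10)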