in levit_c.py [0:0]
def __init__(self, in_dim, out_dim, key_dim, num_heads=8,
             attn_ratio=2,
             activation=None,
             stride=2,
             resolution=14, resolution_=7):
    super().__init__()
    self.num_heads = num_heads
    self.scale = key_dim ** -0.5  # standard 1/sqrt(d_k) attention scaling
    self.key_dim = key_dim
    self.nh_kd = nh_kd = key_dim * num_heads
    self.d = int(attn_ratio * key_dim)                # per-head value dimension
    self.dh = int(attn_ratio * key_dim) * self.num_heads
    self.attn_ratio = attn_ratio
    self.resolution_ = resolution_
    self.resolution_2 = resolution_**2
    h = self.dh + nh_kd
    # Keys and values are computed on the full-resolution grid...
    self.kv = Conv2d_BN(in_dim, h, resolution=resolution)
    # ...while queries are taken on the subsampled grid: AvgPool2d with
    # kernel size 1 and the given stride simply picks every stride-th position.
    self.q = torch.nn.Sequential(
        torch.nn.AvgPool2d(1, stride, 0),
        Conv2d_BN(in_dim, nh_kd, resolution=resolution_))
    # Output projection: activation followed by a Conv2d_BN mapping the
    # concatenated head outputs (self.d * num_heads channels) to out_dim.
    self.proj = torch.nn.Sequential(
        activation(),
        Conv2d_BN(self.d * num_heads, out_dim, resolution=resolution_))
    self.stride = stride
    self.resolution = resolution
    # Build the relative-position attention bias table: every (query, key)
    # pair that shares the same absolute spatial offset reuses the same
    # learnable bias slot (one per head).
    points = list(itertools.product(range(resolution), range(resolution)))
    points_ = list(itertools.product(
        range(resolution_), range(resolution_)))
    N = len(points)    # number of key positions (full grid)
    N_ = len(points_)  # number of query positions (subsampled grid)
    attention_offsets = {}
    idxs = []
    for p1 in points_:
        for p2 in points:
            size = 1
            offset = (
                abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                abs(p1[1] * stride - p2[1] + (size - 1) / 2))
            if offset not in attention_offsets:
                attention_offsets[offset] = len(attention_offsets)
            idxs.append(attention_offsets[offset])
    # One learnable bias per head and per distinct offset (e.g. 14 * 14 = 196
    # offsets for resolution=14, resolution_=7, stride=2), plus an index
    # buffer mapping each (query, key) pair to its bias slot.
    self.attention_biases = torch.nn.Parameter(
        torch.zeros(num_heads, len(attention_offsets)))
    self.register_buffer('attention_bias_idxs',
                         torch.LongTensor(idxs).view(N_, N))
    global FLOPS_COUNTER
    # queries * keys
    FLOPS_COUNTER += num_heads * \
        (resolution**2) * (resolution_**2) * key_dim
    # softmax
    FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2)
    # attention * v
    FLOPS_COUNTER += num_heads * \
        (resolution**2) * (resolution_**2) * self.d
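
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal sketch, assuming this __init__ belongs to LeViT's
# AttentionSubsample module and that Conv2d_BN and an activation such as
# torch.nn.Hardswish are in scope; the argument values below are
# hypothetical and mirror the defaults above.
#
#   attn_down = AttentionSubsample(
#       in_dim=256, out_dim=384, key_dim=16, num_heads=8,
#       attn_ratio=2, activation=torch.nn.Hardswish,
#       stride=2, resolution=14, resolution_=7)
#   x = torch.randn(1, 256, 14, 14)   # (B, C, H, W) feature map
#   y = attn_down(x)                  # expected shape: (1, 384, 7, 7)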