in video_generation.py [0:0]
def _inference(self, inp: str, out: str):
    """Write a head-averaged self-attention heatmap for every ``*.jpg`` in ``inp``.

    Frames are processed in sorted filename order. Each frame goes through these steps:

    1. It is decoded as RGB.
    2. It is converted to a tensor, optionally resized (``self.args.resize``) and normalized.
    3. It is cropped so both spatial dimensions are multiples of ``self.args.patch_size``.
    4. It is passed through ``self.model.get_last_selfattention``.

    The attention from query token 0 to the remaining tokens is then averaged over heads. It is upsampled to pixel resolution by nearest-neighbour interpolation and saved as ``<out>/attn-<frame file name>`` with the ``inferno`` colormap.

    Args:
        inp: Directory containing the input frames; only ``*.jpg`` files are read.
        out: Directory the heatmaps are written to. Nothing in this method
            creates it, so it must already exist.
    """
    print(f"Generating attention images to {out}")
    patch_size = self.args.patch_size

    # The preprocessing pipeline does not depend on the frame, so it is built
    # once. The order is ToTensor -> (optional) Resize -> Normalize. The
    # Normalize constants are the conventional ImageNet mean/std.
    steps = [pth_transforms.ToTensor()]
    if self.args.resize is not None:
        steps.append(pth_transforms.Resize(self.args.resize))
    steps.append(
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    )
    transform = pth_transforms.Compose(steps)

    for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
        # Image.open is lazy; convert() forces the decode while the file
        # handle is still open.
        with open(img_path, "rb") as f:
            img = Image.open(f).convert("RGB")
        img = transform(img)  # ToTensor yields a (C, H, W) tensor

        # Crop height and width down to exact multiples of the patch size so
        # the patch grid covers the image with no remainder.
        h = img.shape[1] - img.shape[1] % patch_size
        w = img.shape[2] - img.shape[2] % patch_size
        img = img[:, :h, :w].unsqueeze(0)
        h_featmap = h // patch_size
        w_featmap = w // patch_size

        # Inference only: no_grad avoids building an autograd graph. It also
        # keeps the later .numpy() from failing if any model parameter
        # requires grad.
        with torch.no_grad():
            attentions = self.model.get_last_selfattention(img.to(DEVICE))
        nh = attentions.shape[1]  # number of heads
        # For each head, keep row 0 of the query axis (the "output patch";
        # presumably the [CLS] token - confirm against the model). Drop its
        # attention to itself (key 0) and lay the rest out on the patch grid.
        attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)

        # Average over heads at patch resolution, then upsample. Nearest
        # upsampling only replicates values, so this matches upsampling each
        # head first and averaging afterwards, at 1/nh of the cost.
        heatmap = attentions.mean(dim=0, keepdim=True).unsqueeze(0)
        heatmap = (
            nn.functional.interpolate(
                heatmap,
                scale_factor=patch_size,
                mode="nearest",
            )[0, 0]
            .cpu()
            .numpy()
        )

        # save attention heatmap
        fname = os.path.join(out, "attn-" + os.path.basename(img_path))
        plt.imsave(
            fname=fname,
            arr=heatmap,
            cmap="inferno",
            format="jpg",
        )