def _inference()

in video_generation.py


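The method assumes the following module-level imports plus a `DEVICE` constant. This is a minimal sketch of that preamble (the exact `DEVICE` definition in the script may differ; applying `Resize` to a tensor also requires torchvision 0.8+):

    import glob
    import os

    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    from PIL import Image
    from torchvision import transforms as pth_transforms
    from tqdm import tqdm

    # assumption: the actual script may define DEVICE differently
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
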
    def _inference(self, inp: str, out: str):
        print(f"Generating attention images to {out}")

        for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
            # decode inside the with-block: PIL loads lazily, so convert()
            # must run before the file handle closes
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")

            # preprocessing: ToTensor, optional Resize, ImageNet mean/std normalization
            transform_steps = [pth_transforms.ToTensor()]
            if self.args.resize is not None:
                transform_steps.append(pth_transforms.Resize(self.args.resize))
            transform_steps.append(
                pth_transforms.Normalize(
                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
                )
            )
            transform = pth_transforms.Compose(transform_steps)

            img = transform(img)

            # crop so height and width are divisible by the patch size
            # (img is CHW: shape[1] is height, shape[2] is width)
            w, h = (
                img.shape[1] - img.shape[1] % self.args.patch_size,
                img.shape[2] - img.shape[2] % self.args.patch_size,
            )
            img = img[:, :w, :h].unsqueeze(0)

            w_featmap = img.shape[-2] // self.args.patch_size
            h_featmap = img.shape[-1] // self.args.patch_size

            attentions = self.model.get_last_selfattention(img.to(DEVICE))
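            # attentions: (1, num_heads, num_tokens, num_tokens); token 0 is [CLS]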

            nh = attentions.shape[1]  # number of attention heads

            # keep only the [CLS] token's attention over the patch tokens
            attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

            # keep only a fraction `threshold` of the attention mass per head:
            # sort ascending, renormalize (the [CLS] column was dropped, so rows
            # no longer sum to 1), and flag the top patches whose cumulative
            # mass reaches the threshold
            val, idx = torch.sort(attentions)
            val /= torch.sum(val, dim=1, keepdim=True)
            cumval = torch.cumsum(val, dim=1)
            th_attn = cumval > (1 - self.args.threshold)
            # undo the sort so each head's mask is back in patch order
            idx2 = torch.argsort(idx)
            for head in range(nh):
                th_attn[head] = th_attn[head][idx2[head]]
            th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
            # upsample the per-head mask from the patch grid to pixel resolution
            th_attn = (
                nn.functional.interpolate(
                    th_attn.unsqueeze(0),
                    scale_factor=self.args.patch_size,
                    mode="nearest",
                )[0]
                .cpu()
                .numpy()
            )
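            # NOTE: th_attn is computed and upsampled but not used below;
            # only the raw attention maps are written to disk.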

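            # reshape the raw attention maps to the patch grid and upsample them too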
            attentions = attentions.reshape(nh, w_featmap, h_featmap)
            attentions = (
                nn.functional.interpolate(
                    attentions.unsqueeze(0),
                    scale_factor=self.args.patch_size,
                    mode="nearest",
                )[0]
                .cpu()
                .numpy()
            )

            # save the head-averaged attention heatmap
            fname = os.path.join(out, "attn-" + os.path.basename(img_path))
            plt.imsave(
                fname=fname,
                arr=attentions.mean(axis=0),
                cmap="inferno",
                format="jpg",
            )
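
The sort/cumsum/argsort round trip is the subtlest step above: per head, it finds the smallest set of patches that together hold at least `threshold` of the attention mass, then restores patch order. A self-contained sketch with toy values (a hypothetical 2-head, 4-patch attention matrix, threshold 0.6) makes the mechanics concrete:

    import torch

    # toy attention for 2 heads over 4 patches (hypothetical values)
    attn = torch.tensor([[0.1, 0.4, 0.3, 0.2],
                         [0.7, 0.1, 0.1, 0.1]])
    threshold = 0.6

    val, idx = torch.sort(attn)               # ascending per head
    val = val / val.sum(dim=1, keepdim=True)  # renormalize each row
    cumval = torch.cumsum(val, dim=1)
    th = cumval > (1 - threshold)             # flag the top-mass patches
    # undo the sort (equivalent to the per-head loop in _inference)
    th = th.gather(1, torch.argsort(idx))

    print(th)
    # tensor([[False,  True,  True, False],    head 0: patches 1+2 hold 0.7 >= 0.6
    #         [ True, False, False, False]])   head 1: patch 0 alone holds 0.7 >= 0.6

The `1 - threshold` comparison works because `cumval` is ascending: the entries above it are exactly the largest values whose combined mass is at least `threshold`.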