pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py [80:110]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        self._copy_weights_to_backbone_mmt()

    def _copy_weights_to_backbone_mmt(self) -> None:
        dist = {}
        for name, p in self.backbone.named_parameters():
            dist[name] = p
        for name, p in self.backbone_mmt.named_parameters():
            p.data.copy_(dist[name].data)

    @torch.no_grad()
    def momentum_update_backbone(self) -> None:
        """
        Momentum (exponential moving average) update on the backbone.

        For every parameter ``p_mmt`` of ``self.backbone_mmt`` and the
        identically named parameter ``p`` of ``self.backbone``::

            p_mmt <- m * p_mmt + (1 - m) * p,    where m = self.mmt

        The update is applied in place, so each momentum parameter keeps its
        original storage and no new tensor is allocated per step.

        Raises:
            KeyError: if ``self.backbone_mmt`` has a parameter name that
                ``self.backbone`` does not.
        """
        m = self.mmt
        online_params = dict(self.backbone.named_parameters())
        for name, p_mmt in self.backbone_mmt.named_parameters():
            # lerp_(end, w) computes self + w * (end - self), which with
            # w = 1 - m equals m * p_mmt + (1 - m) * p. Doing it in place
            # (safe under no_grad) avoids rebinding ``.data`` to a new tensor.
            p_mmt.lerp_(online_params[name], 1.0 - m)

    @torch.no_grad()
    def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward momentum backbone.
        Args:
            x (tensor): input to be forwarded of shape N x C x T x H x W
        """
        with torch.no_grad():
            proj = self.backbone_mmt(x)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py [126:156]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        self._copy_weights_to_backbone_mmt()

    def _copy_weights_to_backbone_mmt(self) -> None:
        dist = {}
        for name, p in self.backbone.named_parameters():
            dist[name] = p
        for name, p in self.backbone_mmt.named_parameters():
            p.data.copy_(dist[name].data)

    @torch.no_grad()
    def momentum_update_backbone(self) -> None:
        """
        Momentum (exponential moving average) update on the backbone.

        For every parameter ``p_mmt`` of ``self.backbone_mmt`` and the
        identically named parameter ``p`` of ``self.backbone``::

            p_mmt <- m * p_mmt + (1 - m) * p,    where m = self.mmt

        The update is applied in place, so each momentum parameter keeps its
        original storage and no new tensor is allocated per step.

        Raises:
            KeyError: if ``self.backbone_mmt`` has a parameter name that
                ``self.backbone`` does not.
        """
        m = self.mmt
        online_params = dict(self.backbone.named_parameters())
        for name, p_mmt in self.backbone_mmt.named_parameters():
            # lerp_(end, w) computes self + w * (end - self), which with
            # w = 1 - m equals m * p_mmt + (1 - m) * p. Doing it in place
            # (safe under no_grad) avoids rebinding ``.data`` to a new tensor.
            p_mmt.lerp_(online_params[name], 1.0 - m)

    @torch.no_grad()
    def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward momentum backbone.
        Args:
            x (tensor): input to be forwarded of shape N x C x T x H x W
        """
        with torch.no_grad():
            proj = self.backbone_mmt(x)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



