equivariance_measure/embedding_alignments.py [75:103]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.model_pth:
            model = self.load_pretrained_ddp(model)
        embedding_model = torch.nn.Sequential(*(list(model.children())[:-1]))
        return embedding_model

    def load_pretrained_ddp(self, model) -> torch.nn.Module:
        """Load a DDP-trained checkpoint into a bare (unwrapped) model.

        Reads the checkpoint at ``self.model_pth`` (expected to contain a
        ``"state_dict"`` entry), strips the ``"module."`` prefix that
        DistributedDataParallel adds to every parameter key, and loads the
        result into ``model`` in place.

        Returns the same ``model`` instance with the weights loaded.
        """
        # NOTE(review): no map_location is passed, so a GPU-saved checkpoint
        # fails to load on a CPU-only machine — confirm whether
        # map_location="cpu" (or the current device) is wanted here.
        state_dict = torch.load(self.model_pth)["state_dict"]
        prefix = "module."
        # Strip the DDP wrapper prefix; keys that lack it pass through
        # unchanged (the original unconditional k[7:] slice would silently
        # corrupt such keys).
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
        return model

    def forward(self, x):
        """Embed a batch with the wrapped model, dropping singleton dims.

        NOTE(review): squeeze() with no dim removes *every* size-1 axis, so a
        batch of size 1 also loses its batch dimension — confirm downstream
        code tolerates this.
        """
        embeddings = self.model(x)
        return torch.squeeze(embeddings)

    def test_step(self, batch, batch_idx):
        x, labels = batch
        z = self(x)

        for magnitude_idx in range(10):
            transform = transformations.Transformation(
                self.transform_name, magnitude_idx
            )
            x_t = transform(x)
            z_t = self(x_t)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



equivariance_measure/embedding_distances.py [139:167]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.model_pth:
            model = self.load_pretrained_ddp(model)
        embedding_model = torch.nn.Sequential(*(list(model.children())[:-1]))
        return embedding_model

    def load_pretrained_ddp(self, model) -> torch.nn.Module:
        """Load a DDP-trained checkpoint into a bare (unwrapped) model.

        Reads the checkpoint at ``self.model_pth`` (expected to contain a
        ``"state_dict"`` entry), strips the ``"module."`` prefix that
        DistributedDataParallel adds to every parameter key, and loads the
        result into ``model`` in place.

        Returns the same ``model`` instance with the weights loaded.
        """
        # NOTE(review): no map_location is passed, so a GPU-saved checkpoint
        # fails to load on a CPU-only machine — confirm whether
        # map_location="cpu" (or the current device) is wanted here.
        state_dict = torch.load(self.model_pth)["state_dict"]
        prefix = "module."
        # Strip the DDP wrapper prefix; keys that lack it pass through
        # unchanged (the original unconditional k[7:] slice would silently
        # corrupt such keys).
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
        return model

    def forward(self, x):
        """Embed a batch with the wrapped model, dropping singleton dims.

        NOTE(review): squeeze() with no dim removes *every* size-1 axis, so a
        batch of size 1 also loses its batch dimension — confirm downstream
        code tolerates this.
        """
        embeddings = self.model(x)
        return torch.squeeze(embeddings)

    def test_step(self, batch, batch_idx):
        x, labels = batch
        z = self(x)

        for magnitude_idx in range(10):
            transform = transformations.Transformation(
                self.transform_name, magnitude_idx
            )
            x_t = transform(x)
            z_t = self(x_t)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



