experiments/overlap/datasets.py [74:115]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            return self.aug.convert_from_numpy(params)

    def __getitem__(self, index):
        '''
        Return the augmented sample at ``index`` together with its label.

        The sample is fetched from ``self.dataset``, passed through
        ``self.pretransform``, augmented by ``self.aug.transform`` using the
        parameters obtained from ``self.get_random_transform()`` (expanded as
        keyword arguments), and finally passed through ``self.posttransform``.
        The label is returned unchanged.
        '''
        raw, target = self.dataset[index]
        prepared = self.pretransform(raw)
        # Fresh parameters are requested on every access.
        aug_kwargs = self.get_random_transform()
        augmented = self.aug.transform(prepared, **aug_kwargs)
        return self.posttransform(augmented), target

    def __len__(self):
        '''
        Return the number of entries in the wrapped ``self.dataset``.
        '''
        wrapped = self.dataset
        return len(wrapped)


    def fixed_transform(self, index, transform_index):
        '''
        Apply one entry of the fixed transform list to one dataset sample.

        Args:
            index: Index into ``self.dataset``.
            transform_index: Index into ``self.transform_list``.

        Returns:
            A ``(sample, label)`` tuple. The sample has been passed through
            ``self.pretransform``, then ``self.aug.transform`` with the
            selected parameters, then ``self.posttransform``. The label is
            returned unchanged.

        Raises:
            AssertionError: If ``self.transform_list`` is ``None``.
        '''
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O; exception type and message are
        # the same as before.
        if self.transform_list is None:
            raise AssertionError("Must have a fixed transform list to generate fixed transforms.")
        im, label = self.dataset[index]
        im = self.pretransform(im)
        # convert_from_numpy turns the stored entry into the keyword
        # arguments that aug.transform receives.
        params = self.aug.convert_from_numpy(self.transform_list[transform_index])
        im = self.aug.transform(im, **params)
        return self.posttransform(im), label
        
    def serialize(self, indices=None):
        '''
        Build a dataset view that enumerates every fixed transform for each
        sample, in order.

        Item ``k`` of the returned dataset is
        ``self.fixed_transform(image, k % T)`` where ``T`` is
        ``len(self.transform_list)`` and ``image`` is ``k // T``, or
        ``indices[k // T]`` when ``indices`` is given. Its length is ``T``
        times ``len(indices)`` if ``indices`` is given, otherwise ``T`` times
        ``len(self)``.
        '''
        class SerialDataset(torch.utils.data.Dataset):
            def __init__(self, dataset, indices=None):
                self.indices = indices
                self.dataset = dataset

            def __getitem__(self, index):
                per_image = len(self.dataset.transform_list)
                # Consecutive runs of ``per_image`` items share one sample.
                slot = index // per_image
                if self.indices is None:
                    im_idx = slot
                else:
                    im_idx = self.indices[slot]
                return self.dataset.fixed_transform(im_idx, index % per_image)

            def __len__(self):
                # Count over the explicit subset when given, else the
                # whole wrapped dataset.
                base = self.dataset if self.indices is None else self.indices
                return len(base) * len(self.dataset.transform_list)

        return SerialDataset(self, indices)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



experiments/overlap/datasets.py [330:371]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            return self.aug.convert_from_numpy(params)

    def __getitem__(self, index):
        '''
        Return the augmented sample at ``index`` together with its label.

        The sample is fetched from ``self.dataset``, passed through
        ``self.pretransform``, augmented by ``self.aug.transform`` using the
        parameters obtained from ``self.get_random_transform()`` (expanded as
        keyword arguments), and finally passed through ``self.posttransform``.
        The label is returned unchanged.
        '''
        raw, target = self.dataset[index]
        prepared = self.pretransform(raw)
        # Fresh parameters are requested on every access.
        aug_kwargs = self.get_random_transform()
        augmented = self.aug.transform(prepared, **aug_kwargs)
        return self.posttransform(augmented), target

    def __len__(self):
        '''
        Return the number of entries in the wrapped ``self.dataset``.
        '''
        wrapped = self.dataset
        return len(wrapped)


    def fixed_transform(self, index, transform_index):
        '''
        Apply one entry of the fixed transform list to one dataset sample.

        Args:
            index: Index into ``self.dataset``.
            transform_index: Index into ``self.transform_list``.

        Returns:
            A ``(sample, label)`` tuple. The sample has been passed through
            ``self.pretransform``, then ``self.aug.transform`` with the
            selected parameters, then ``self.posttransform``. The label is
            returned unchanged.

        Raises:
            AssertionError: If ``self.transform_list`` is ``None``.
        '''
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O; exception type and message are
        # the same as before.
        if self.transform_list is None:
            raise AssertionError("Must have a fixed transform list to generate fixed transforms.")
        im, label = self.dataset[index]
        im = self.pretransform(im)
        # convert_from_numpy turns the stored entry into the keyword
        # arguments that aug.transform receives.
        params = self.aug.convert_from_numpy(self.transform_list[transform_index])
        im = self.aug.transform(im, **params)
        return self.posttransform(im), label
        
    def serialize(self, indices=None):
        '''
        Build a dataset view that enumerates every fixed transform for each
        sample, in order.

        Item ``k`` of the returned dataset is
        ``self.fixed_transform(image, k % T)`` where ``T`` is
        ``len(self.transform_list)`` and ``image`` is ``k // T``, or
        ``indices[k // T]`` when ``indices`` is given. Its length is ``T``
        times ``len(indices)`` if ``indices`` is given, otherwise ``T`` times
        ``len(self)``.
        '''
        class SerialDataset(torch.utils.data.Dataset):
            def __init__(self, dataset, indices=None):
                self.indices = indices
                self.dataset = dataset

            def __getitem__(self, index):
                per_image = len(self.dataset.transform_list)
                # Consecutive runs of ``per_image`` items share one sample.
                slot = index // per_image
                if self.indices is None:
                    im_idx = slot
                else:
                    im_idx = self.indices[slot]
                return self.dataset.fixed_transform(im_idx, index % per_image)

            def __len__(self):
                # Count over the explicit subset when given, else the
                # whole wrapped dataset.
                base = self.dataset if self.indices is None else self.indices
                return len(base) * len(self.dataset.transform_list)

        return SerialDataset(self, indices)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



