generate_creative_birds.py [28:56]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Initialstroke_Dataset(data.Dataset):
    """Dataset of every ``*.png`` file found recursively under ``folder``.

    Each item is the image opened with PIL and converted to a tensor by
    ``transforms.ToTensor()``; no resizing or mode conversion is applied.

    Args:
        folder: Root directory searched recursively for ``.png`` files.
        image_size: Stored as ``self.image_size`` but not used by this class
            (images are returned at their native size).
    """

    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Sort so the index -> file mapping is deterministic; raw glob order
        # depends on the filesystem.
        self.paths = sorted(Path(folder).glob('**/*.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def _load(self, path):
        """Open the image at ``path`` and return it passed through ``self.transform``."""
        # The context manager releases the file handle as soon as the pixels
        # have been read, instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        return self._load(self.paths[index])

    def sample(self, n, device='cuda'):
        """Return ``n`` randomly chosen images stacked into one tensor.

        Indices are drawn uniformly with replacement from NumPy's global RNG,
        one ``np.random.randint`` call per sample, so seeded runs reproduce
        the same draw sequence.

        Args:
            n: Number of images to draw.
            device: Device the stacked tensor is moved to (default ``'cuda'``).

        Returns:
            ``torch.stack`` of the ``n`` loaded images, on ``device``.

        Raises:
            ValueError: If the dataset contains no images.
        """
        if not self.paths:
            raise ValueError(f'no .png files found under {self.folder!r}; cannot sample')
        sample_ids = [np.random.randint(len(self)) for _ in range(n)]
        # torch.stack requires every sampled image to have the same shape.
        samples = [self._load(self.paths[i]) for i in sample_ids]
        return torch.stack(samples).to(device)

def load_latest(model_dir, name):
    model_dir = Path(model_dir)
    file_paths = [p for p in Path(model_dir / name).glob('model_*.pt')]
    saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
    if len(saved_nums) == 0:
        return
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



generate_creative_creatures.py [32:61]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class Initialstroke_Dataset(data.Dataset):
    """Dataset of every ``*.png`` file found recursively under ``folder``.

    Each item is the image opened with PIL and converted to a tensor by
    ``transforms.ToTensor()``; no resizing or mode conversion is applied.

    Args:
        folder: Root directory searched recursively for ``.png`` files.
        image_size: Stored as ``self.image_size`` but not used by this class
            (images are returned at their native size).
    """

    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Sort so the index -> file mapping is deterministic; raw glob order
        # depends on the filesystem.
        self.paths = sorted(Path(folder).glob('**/*.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def _load(self, path):
        """Open the image at ``path`` and return it passed through ``self.transform``."""
        # The context manager releases the file handle as soon as the pixels
        # have been read, instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        return self._load(self.paths[index])

    def sample(self, n, device='cuda'):
        """Return ``n`` randomly chosen images stacked into one tensor.

        Indices are drawn uniformly with replacement from NumPy's global RNG,
        one ``np.random.randint`` call per sample, so seeded runs reproduce
        the same draw sequence.

        Args:
            n: Number of images to draw.
            device: Device the stacked tensor is moved to (default ``'cuda'``).

        Returns:
            ``torch.stack`` of the ``n`` loaded images, on ``device``.

        Raises:
            ValueError: If the dataset contains no images.
        """
        if not self.paths:
            raise ValueError(f'no .png files found under {self.folder!r}; cannot sample')
        sample_ids = [np.random.randint(len(self)) for _ in range(n)]
        # torch.stack requires every sampled image to have the same shape.
        samples = [self._load(self.paths[i]) for i in sample_ids]
        return torch.stack(samples).to(device)


def load_latest(model_dir, name):
    model_dir = Path(model_dir)
    file_paths = [p for p in Path(model_dir / name).glob('model_*.pt')]
    saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
    if len(saved_nums) == 0:
        return
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



