in part_selector.py [0:0]
def __init__(self, name, results_dir, models_dir, n_part, image_size, network_capacity, batch_size = 4,
             gradient_accumulate_every=1, lr = 2e-4, num_workers = None, save_every = 1000):
    """Store the training configuration and pick the part vocabulary for this run.

    Args:
        name: Run name. It is the sub-directory of ``models_dir`` holding
            ``.config.json``. It also selects the part vocabulary by substring
            match: names containing ``'bird'`` get the 9 bird parts; otherwise
            names containing ``'generic'``, ``'fin'`` or ``'horn'`` get the
            18 generic parts.
        results_dir: Results directory, stored as a ``pathlib.Path``.
        models_dir: Models directory, stored as a ``pathlib.Path``.
        n_part: Stored unchanged as ``self.n_part``.
        image_size: Must be a positive power of 2.
        network_capacity: Stored unchanged.
        batch_size, gradient_accumulate_every, lr, num_workers, save_every:
            Training hyper-parameters, stored unchanged.

    Raises:
        ValueError: if ``image_size`` is not a positive power of 2, or if
            ``name`` matches none of the known part vocabularies.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # the ``<= 0`` guard avoids log2's opaque "math domain error".
    if image_size <= 0 or not log2(image_size).is_integer():
        raise ValueError('image size must be a power of 2 (64, 128, 256, 512, 1024)')

    self.clf = None  # presumably the classifier, built later — confirm
    self.name = name
    self.results_dir = Path(results_dir)
    self.models_dir = Path(models_dir)
    self.config_path = self.models_dir / name / '.config.json'
    self.n_part = n_part
    self.image_size = image_size
    self.network_capacity = network_capacity
    self.lr = lr
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.save_every = save_every
    self.steps = 0
    self.gradient_accumulate_every = gradient_accumulate_every
    # Running metrics, presumably updated by the training step — confirm.
    self.d_loss = 0
    self.d_acc = 0
    self.loader = None  # data loader placeholder; not created here
    self.criterion = nn.CrossEntropyLoss()

    # NOTE: plain substring matching, so e.g. 'fin' also matches 'final'.
    # 'bird' is tested first and wins if several keywords are present.
    if 'bird' in name:
        target_parts = ['eye', 'head', 'body', 'beak', 'legs', 'wing', 'mouth', 'tail', 'none']
    elif any(keyword in name for keyword in ('generic', 'fin', 'horn')):
        target_parts = ['eye', 'arms', 'beak', 'mouth', 'body', 'ears', 'feet', 'fin',
                        'hair', 'hands', 'head', 'horns', 'legs', 'nose', 'paws', 'tail', 'wings', 'none']
    else:
        # Previously this fell through and crashed with AttributeError below.
        raise ValueError(
            f'cannot infer part vocabulary from run name {name!r}; '
            "expected it to contain 'bird', 'generic', 'fin' or 'horn'"
        )
    self.target_parts = target_parts
    self.n_part_class = len(self.target_parts)