in train_helpers.py [0:0]
def load_vaes(H, logprint):
    """Build the trainable VAE and its EMA copy, optionally restoring weights.

    Args:
        H: hyperparameter namespace. Passed whole to ``VAE``; this function
            also reads ``H.restore_path``, ``H.restore_ema_path``,
            ``H.local_rank`` and ``H.mpi_size``.
        logprint: logging callable, called both with a positional message and
            with keyword fields (``total_params``, ``readable``).

    Returns:
        ``(vae, ema_vae)`` where ``vae`` is moved to GPU ``H.local_rank`` and
        wrapped in ``DistributedDataParallel``, and ``ema_vae`` is moved to
        the same GPU with gradients disabled but is NOT wrapped.

    Raises:
        ValueError: if the wrapped model yields a different number of
            entries from ``named_parameters()`` than from ``parameters()``.
    """
    vae = VAE(H)
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        restore_params(vae, H.restore_path, map_cpu=True,
                       local_rank=H.local_rank, mpi_size=H.mpi_size)

    ema_vae = VAE(H)
    if H.restore_ema_path:
        logprint(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True,
                       local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        # No dedicated EMA checkpoint: start the EMA copy from the online
        # model's current (possibly just-restored) weights.
        ema_vae.load_state_dict(vae.state_dict())
    # Gradients are disabled on the EMA copy; presumably its weights are
    # updated outside autograd by the training loop — TODO confirm.
    ema_vae.requires_grad_(False)

    vae = vae.cuda(H.local_rank)
    ema_vae = ema_vae.cuda(H.local_rank)
    # Only the trainable model is wrapped for distributed gradient sync.
    vae = DistributedDataParallel(vae, device_ids=[H.local_rank],
                                  output_device=H.local_rank)

    # NOTE(review): in current PyTorch `parameters()` is derived from
    # `named_parameters()`, so this guard is expected to always pass.
    if len(list(vae.named_parameters())) != len(list(vae.parameters())):
        raise ValueError('Some params are not named. Please name all params.')

    # numel() returns a Python int and is 1 for 0-dim parameters, whereas
    # np.prod(()) returns the float 1.0 and would turn the total into a float.
    total_params = sum(p.numel() for p in vae.parameters())
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae