in src/train.py [0:0]
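# Trainer setup: datasets, loss meters, models, optional checkpoint resume,
# multi-GPU wrapping, optimizers, and the LR schedule.
# Imports assumed by this excerpt; the commented project-local module paths
# are assumptions, the rest are standard library / PyTorch.
import os
from itertools import chain
from pathlib import Path

import torch
import torch.optim as optim
# from data import DatasetSet                                   # assumed
# from wavenet_models import Encoder, WaveNet, ZDiscriminator   # assumed
# from utils import create_output_dir, LossMeter                # assumed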
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(self.args.data)
    self.expPath = Path('checkpoints') / args.expName
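    # Seed both the CPU and CUDA RNGs so runs are reproducible.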
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    self.logger = create_output_dir(args, self.expPath)
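    # One DatasetSet per entry in args.data; in distributed mode there must be
    # exactly one dataset per node.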
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
    assert not args.distributed or len(self.data) == int(
        os.environ['WORLD_SIZE']), "Number of datasets must match number of nodes"
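    # Running loss meters: one reconstruction meter per dataset plus
    # discriminator and total meters, mirrored for evaluation.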
    self.losses_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
    self.loss_d_right = LossMeter('d')
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
    self.eval_d_right = LossMeter('eval d')
    self.eval_total = LossMeter('eval total')
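    # Model components: an encoder, a WaveNet decoder, and a discriminator
    # over the encoder's latent codes.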
    self.encoder = Encoder(args)
    self.decoder = WaveNet(args)
    self.discriminator = ZDiscriminator(args)
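    # Optionally resume: args.pth sits next to the checkpoint file and stores
    # the run arguments, with the last completed epoch as its final element.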
    if args.checkpoint:
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)
        self.start_epoch = checkpoint_args[-1] + 1
        states = torch.load(args.checkpoint)

        self.encoder.load_state_dict(states['encoder_state'])
        self.decoder.load_state_dict(states['decoder_state'])
        self.discriminator.load_state_dict(states['discriminator_state'])
        self.logger.info('Loaded checkpoint parameters')
    else:
        self.start_epoch = 0
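    # Multi-GPU wrapping: DistributedDataParallel for encoder and discriminator
    # when distributed; the decoder is wrapped in DataParallel either way.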
    if args.distributed:
        self.encoder.cuda()
        self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
        self.discriminator.cuda()
        self.discriminator = torch.nn.parallel.DistributedDataParallel(self.discriminator)
        self.logger.info('Created DistributedDataParallel')
    else:
        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()
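    # Two optimizers: one over the chained encoder and decoder parameters, and
    # a separate one so the discriminator can be stepped independently.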
    self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                            self.decoder.parameters()),
                                      lr=args.lr)
    self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                  lr=args.lr)
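    # Optimizer state is restored only on request; `states` is in scope here
    # because args.checkpoint was truthy when it was loaded above.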
    if args.checkpoint and args.load_optimizer:
        self.model_optimizer.load_state_dict(states['model_optimizer_state'])
        self.d_optimizer.load_state_dict(states['d_optimizer_state'])
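    # Exponential LR decay; setting last_epoch before step() fast-forwards the
    # schedule so a resumed run continues from the already-decayed rate.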
    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(self.model_optimizer, args.lr_decay)
    self.lr_manager.last_epoch = self.start_epoch
    self.lr_manager.step()