in rlkit/torch/vae/vae_trainer.py [0:0]
# Relies on the module-level imports in vae_trainer.py: `np` (numpy) and
# `ptu` (rlkit.torch.pytorch_util).
def train_epoch(self, epoch, sample_batch=None, batches=100, from_rl=False):
    self.model.train()
    losses = []
    log_probs = []
    kles = []
    zs = []
    # Anneal the KL weight for this epoch according to the beta schedule.
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(batches):
        if sample_batch is not None:
            # Online path: the caller supplies a sampling function that
            # returns a dict of tensors (e.g. drawn from a replay buffer).
            data = sample_batch(self.batch_size, epoch)
            # obs = data['obs']
            next_obs = data['next_obs']
            # actions = data['actions']
        else:
            # Offline path: draw a batch from the trainer's own dataset.
            next_obs = self.get_batch(epoch=epoch)
            obs = None
            actions = None
        self.optimizer.zero_grad()
        # Forward pass: reconstructions plus the decoder and encoder
        # distribution parameters.
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(next_obs)
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        # Record the encoder means so latent statistics can be fit after the epoch.
        encoder_mean = self.model.get_encoding_from_latent_distribution_params(latent_distribution_params)
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        # Negative beta-VAE ELBO: reconstruction log-likelihood plus beta-weighted KL.
        loss = -1 * log_prob + beta * kle
        loss.backward()
        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())
        self.optimizer.step()
        if self.log_interval and batch_idx % self.log_interval == 0:
            # Progress logging: samples processed out of the `batches`
            # gradient steps scheduled for this epoch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(next_obs),
                batches * len(next_obs),
                100. * batch_idx / batches,
                loss.item() / len(next_obs)))
    if not from_rl:
        # Store summary statistics of the encoder means on the model
        # (skipped when training is driven from RL).
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

    self.eval_statistics['train/log prob'] = np.mean(log_probs)
    self.eval_statistics['train/KL'] = np.mean(kles)
    self.eval_statistics['train/loss'] = np.mean(losses)
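
A minimal usage sketch (not from the rlkit source): `make_sampler` is a hypothetical helper that only illustrates the `sample_batch` contract assumed above, i.e. a callable taking `(batch_size, epoch)` and returning a dict whose 'next_obs' entry is a float tensor; the trainer itself (a ConvVAETrainer) is assumed to have been constructed elsewhere.

import numpy as np
import rlkit.torch.pytorch_util as ptu

def make_sampler(replay_images):
    """Wrap a fixed (N, obs_dim) float array in a sample_batch-style callable."""
    def sampler(batch_size, epoch):
        idx = np.random.randint(0, len(replay_images), size=batch_size)
        # train_epoch only reads the 'next_obs' key from this dict.
        return dict(next_obs=ptu.from_numpy(replay_images[idx]))
    return sampler

images = np.random.rand(1000, 48 * 48 * 3).astype(np.float32)  # placeholder data
sampler = make_sampler(images)
assert sampler(batch_size=32, epoch=0)['next_obs'].shape == (32, 48 * 48 * 3)

# With a constructed ConvVAETrainer this would drive one epoch of training:
#     trainer.train_epoch(epoch, sample_batch=sampler, batches=100, from_rl=True)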