in rlkit/launchers/skewfit_experiments.py [0:0]
def train_vae(variant, return_data=False):
    """Generate a VAE dataset, then build and train a ConvVAE on it.

    Args:
        variant: Experiment configuration dict. Required keys read here:
            ``beta``, ``representation_size``,
            ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
            ``save_period`` and ``num_epochs``. Optional keys:
            ``generate_vae_data_fctn`` (defaults to ``generate_vae_dataset``),
            ``beta_schedule_kwargs``, ``decoder_activation`` (only the value
            ``'sigmoid'`` is special-cased) and ``imsize``. The
            ``dump_skew_debug_plots`` key is not used by this function.
        return_data: If True, also return the generated datasets.

    Returns:
        The trained ``ConvVAE`` when ``return_data`` is False, otherwise the
        tuple ``(vae, train_data, test_data)``.

    Side effects:
        Mutates ``variant['vae_kwargs']`` in place by setting its
        ``architecture`` and ``imsize`` entries. Saves the dataset ``info``
        and the final model (``'vae.pkl'``, pickled) via ``logger``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    # The dataset generator is pluggable. It is called with the
    # ``generate_vae_dataset_kwargs`` dict as one positional argument and must
    # return a ``(train_data, test_data, info)`` triple.
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs']
    )
    logger.save_extra_data(info)
    # NOTE(review): the return value is discarded. The call is kept only in
    # case it has side effects; remove it if it is a plain getter.
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # With no explicit architecture configured, fall back to the
    # conv_vae.imsize{84,48}_default_architecture presets for those sizes.
    imsize = variant.get('imsize')
    vae_kwargs = variant['vae_kwargs']
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
        # NOTE(review): any other imsize leaves `architecture` falsy, and it
        # is passed to ConvVAE as-is. Confirm ConvVAE tolerates that.
    # Both values are written back into the caller's dict. `imsize` replaces
    # any value already in vae_kwargs, even when variant has no 'imsize' (None).
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    vae = ConvVAE(
        representation_size,
        decoder_output_activation=decoder_activation,
        **vae_kwargs
    )
    vae.to(ptu.device)

    trainer = ConvVAETrainer(train_data, test_data, vae, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Reconstructions and samples are dumped every `save_period` epochs,
        # starting at epoch 0.
        should_save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    if return_data:
        return vae, train_data, test_data
    return vae