in rlkit/launchers/skewfit_experiments.py [0:0]
def get_video_save_func(rollout_function, env, policy, variant):
    """Build a ``save_video(algo, epoch)`` callback that dumps rollout videos.

    The callback writes into the directory returned by
    ``logger.get_snapshot_dir()`` (looked up once, when this factory is
    called). It does work only when ``epoch`` is a multiple of the save
    period or equals ``algo.num_epochs``.

    Args:
        rollout_function: Forwarded unchanged to ``dump_video``.
        env: Environment to record. When ``do_state_exp`` is set it is wrapped
            in an ``ImageEnv``; otherwise it is used directly and must expose
            an ``imsize`` attribute.
        policy: Forwarded unchanged to ``dump_video``.
        variant: Mapping of experiment settings. Keys read here:
            ``save_video_period`` (default 50), ``do_state_exp`` (default
            False), ``dump_video_kwargs`` (default empty), and, only when
            ``do_state_exp`` is set, ``imsize`` and ``init_camera``.

    Returns:
        A function ``save_video(algo, epoch)``. In the state-experiment case
        it writes ``video_{epoch}_env.mp4``; otherwise it writes both
        ``video_{epoch}_env.mp4`` and ``video_{epoch}_vae.mp4``.
    """
    logdir = logger.get_snapshot_dir()
    save_period = variant.get('save_video_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    # Shallow-copy the kwargs: 'imsize' is written into this dict below, and
    # without the copy that write would leak into the caller's variant.
    dump_video_kwargs = dict(variant.get("dump_video_kwargs", {}))

    if do_state_exp:
        # State-based experiment: the raw env is wrapped in an ImageEnv
        # before being handed to dump_video.
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        def save_video(algo, epoch):
            """Dump one env video if this epoch is a save epoch."""
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                filename = osp.join(
                    logdir,
                    'video_{epoch}_env.mp4'.format(epoch=epoch),
                )
                dump_video(image_env, policy, filename, rollout_function,
                           **dump_video_kwargs)
    else:
        # The env is used as-is and supplies the image size itself.
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize

        def save_video(algo, epoch):
            """Dump an env video and a VAE video if this is a save epoch."""
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                # NOTE(review): temporary_mode presumably switches image_env
                # into the named mode only for the duration of the call --
                # confirm against its definition.
                filename = osp.join(
                    logdir,
                    'video_{epoch}_env.mp4'.format(epoch=epoch),
                )
                temporary_mode(
                    image_env,
                    mode='video_env',
                    func=dump_video,
                    args=(image_env, policy, filename, rollout_function),
                    kwargs=dump_video_kwargs
                )
                filename = osp.join(
                    logdir,
                    'video_{epoch}_vae.mp4'.format(epoch=epoch),
                )
                temporary_mode(
                    image_env,
                    mode='video_vae',
                    func=dump_video,
                    args=(image_env, policy, filename, rollout_function),
                    kwargs=dump_video_kwargs
                )

    return save_video