gym/misc/write_rollout_data.py (41 lines of code) (raw):
"""
This script does a few rollouts with an environment and writes the data to an npz file
Its purpose is to help with verifying that you haven't functionally changed an environment.
(If you have, you should bump the version number.)
"""
import argparse, numpy as np, collections, sys
from os import path
class RandomAgent(object):
def __init__(self, ac_space):
self.ac_space = ac_space
def act(self, _):
return self.ac_space.sample()
def rollout(env, agent, max_episode_steps):
"""
Simulate the env and agent for max_episode_steps
"""
ob = env.reset()
data = collections.defaultdict(list)
for _ in xrange(max_episode_steps):
data["observation"].append(ob)
action = agent.act(ob)
data["action"].append(action)
ob,rew,done,_ = env.step(action)
data["reward"].append(rew)
if done:
break
return data
def main():
parser = argparse.ArgumentParser()
parser.add_argument("envid")
parser.add_argument("outfile")
parser.add_argument("--gymdir")
args = parser.parse_args()
if args.gymdir:
sys.path.insert(0, args.gymdir)
import gym
from gym import utils
print utils.colorize("gym directory: %s"%path.dirname(gym.__file__), "yellow")
env = gym.make(args.envid)
agent = RandomAgent(env.action_space)
alldata = {}
for i in xrange(2):
np.random.seed(i)
data = rollout(env, agent, env.spec.max_episode_steps)
for (k, v) in data.items():
alldata["%i-%s"%(i, k)] = v
np.savez(args.outfile, **alldata)
if __name__ == "__main__":
main()