# train.py
def main():
args = parse_args()
utils.set_seed_everywhere(args.seed)
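# create the training environment: CARLA for the driving task, otherwise a DeepMind Control task via dmc2gym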
if args.domain_name == 'carla':
env = CarlaEnv(
render_display=args.render, # for local debugging only
display_text=args.render, # for local debugging only
changing_weather_speed=0.1, # [0, +inf)
rl_image_size=args.image_size,
max_episode_steps=1000,
frame_skip=args.action_repeat,
is_other_cars=True,
port=args.port
)
# TODO: implement env.seed(args.seed) ?
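# CARLA: evaluation runs on the same environment instance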
eval_env = env
else:
env = dmc2gym.make(
domain_name=args.domain_name,
task_name=args.task_name,
resource_files=args.resource_files,
img_source=args.img_source,
total_frames=args.total_frames,
seed=args.seed,
visualize_reward=False,
from_pixels=(args.encoder_type == 'pixel'),
height=args.image_size,
width=args.image_size,
frame_skip=args.action_repeat
)
env.seed(args.seed)
eval_env = dmc2gym.make(
domain_name=args.domain_name,
task_name=args.task_name,
resource_files=args.eval_resource_files,
img_source=args.img_source,
total_frames=args.total_frames,
seed=args.seed,
visualize_reward=False,
from_pixels=(args.encoder_type == 'pixel'),
height=args.image_size,
width=args.image_size,
frame_skip=args.action_repeat
)
# stack several consecutive frames together
if args.encoder_type.startswith('pixel'):
env = utils.FrameStack(env, k=args.frame_stack)
eval_env = utils.FrameStack(eval_env, k=args.frame_stack)
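# create the working directory plus subdirectories for videos, model checkpoints, and buffer snapshots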
utils.make_dir(args.work_dir)
video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
video = VideoRecorder(video_dir if args.save_video else None)
with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
json.dump(vars(args), f, sort_keys=True, indent=4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# the dmc2gym wrapper standardizes actions
assert env.action_space.low.min() >= -1
assert env.action_space.high.max() <= 1
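# replay buffer holds transitions (obs, action, prev reward, reward, next_obs, done) on the chosen device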
replay_buffer = utils.ReplayBuffer(
obs_shape=env.observation_space.shape,
action_shape=env.action_space.shape,
capacity=args.replay_buffer_capacity,
batch_size=args.batch_size,
device=device
)
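# build the agent specified by the command-line arguments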
agent = make_agent(
obs_shape=env.observation_space.shape,
action_shape=env.action_space.shape,
args=args,
device=device
)
L = Logger(args.work_dir, use_tb=args.save_tb)
episode, episode_reward, done = 0, 0, True
start_time = time.time()
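# main training loop: collect transitions and update the agent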
for step in range(args.num_train_steps):
if done:
if args.decoder_type == 'inverse':
for i in range(1, args.k):  # zero the k_obses entries whose k-step targets would cross the episode boundary
replay_buffer.k_obses[replay_buffer.idx - i] = 0
if step > 0:
L.log('train/duration', time.time() - start_time, step)
start_time = time.time()
L.dump(step)
# evaluate agent periodically
if episode % args.eval_freq == 0:
L.log('eval/episode', episode, step)
evaluate(eval_env, agent, video, args.num_eval_episodes, L, step)
if args.save_model:
agent.save(model_dir, step)
if args.save_buffer:
replay_buffer.save(buffer_dir)
L.log('train/episode_reward', episode_reward, step)
obs = env.reset()
done = False
episode_reward = 0
episode_step = 0
episode += 1
reward = 0
L.log('train/episode', episode, step)
# sample action for data collection
if step < args.init_steps:
action = env.action_space.sample()
else:
with utils.eval_mode(agent):
action = agent.sample_action(obs)
# run training update
if step >= args.init_steps:
num_updates = args.init_steps if step == args.init_steps else 1
for _ in range(num_updates):
agent.update(replay_buffer, L, step)
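# keep the previous step's reward so each stored transition carries both it and the new reward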
curr_reward = reward
next_obs, reward, done, _ = env.step(action)
# allow infinite bootstrap: don't treat a time-limit termination as a true done
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
episode_reward += reward
replay_buffer.add(obs, action, curr_reward, reward, next_obs, done_bool)
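# store next_obs as the k-step-ahead observation for the transition added k steps earlier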
np.copyto(replay_buffer.k_obses[replay_buffer.idx - args.k], next_obs)
obs = next_obs
episode_step += 1