qlearn/atari/train_bootstrapped_agent.py [142:191]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    replay_buffer = ReplayBuffer(args.replay_buffer_size)

    start_time, start_steps = None, None
    steps_per_iter = RunningAvg(0.999)
    iteration_time_est = RunningAvg(0.999)
    obs = env.reset()
    num_iters = 0
    num_episodes = 0
    num_updates = 0
    prev_lives = None
    episode_rewards = [0.0]
    td_errors_list = []
    best_score = None
    k = random.randrange(args.nheads)
    while True:

        num_iters += 1
        # Take action and store transition in the replay buffer.
        if num_iters <= args.learning_starts:
            action = random.randrange(num_actions)
        else:
            # Reshape state to (1, channels, x_dim, y_dim)
            action = agent.act_single_head(np.transpose(np.array(obs)[None], [0, 3, 1, 2]), k)
        new_obs, rew, done, info = env.step(action)
        # Treat losing a life as a terminal transition for learning, while the
        # environment itself keeps running until the game is actually over.
        death = done or (prev_lives is not None and info['ale.lives'] < prev_lives and info['ale.lives'] > 0)
        prev_lives = info['ale.lives']

        # Store the transition with the reward clipped to its sign, as in DQN.
        replay_buffer.add(obs, action, np.sign(rew), new_obs, float(death))
        obs = new_obs
        episode_rewards[-1] += rew

        if done:
            log.add_scalar('reward', episode_rewards[-1], num_iters)
            episode_rewards.append(0.0)
            obs = env.reset()
            num_episodes += 1
            # Resample which head drives behaviour for the next episode, as in bootstrapped DQN.
            k = random.randrange(args.nheads)

        if num_iters > args.learning_starts and num_iters % args.learning_freq == 0:

            obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size)
            # Reshape state to (batch, channels, x_dim, y_dim)
            obses_t = np.transpose(obses_t, [0, 3, 1, 2])
            obses_tp1 = np.transpose(obses_tp1, [0, 3, 1, 2])

            # Perform one gradient step on the sampled batch; the agent returns the TD error.
            td_errors = agent.learn(obses_t, actions, rewards, obses_tp1, dones)
            td_errors_list.append(td_errors.item())
            log.add_scalar('td_error', td_errors.item(), num_iters)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
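
Both loops act greedily with respect to a single bootstrap head `k`, drawn with `random.randrange(args.nheads)` and held fixed for an episode. The sketch below illustrates, under stated assumptions, what such a multi-head Q-network and `act_single_head` could look like: a shared convolutional trunk with one linear Q-value head per bootstrap sample. `BootstrappedQNet`, its layer sizes, and the example shapes are illustrative assumptions, not the repo's implementation.

import random
import numpy as np
import torch
import torch.nn as nn

class BootstrappedQNet(nn.Module):
    """Shared conv trunk with one Q-value head per bootstrap sample (illustrative sketch)."""
    def __init__(self, in_channels, num_actions, nheads):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.heads = nn.ModuleList(nn.Linear(32, num_actions) for _ in range(nheads))

    def forward(self, obs, k):
        # obs: float tensor of shape (batch, channels, H, W); only head k is evaluated.
        return self.heads[k](self.trunk(obs))

def act_single_head(net, obs_chw, k):
    """Greedy action from head k for a single (1, C, H, W) observation."""
    with torch.no_grad():
        q = net(torch.as_tensor(obs_chw, dtype=torch.float32), k)
    return int(q.argmax(dim=1).item())

# The head index k is drawn once per episode and kept fixed until the episode ends.
net = BootstrappedQNet(in_channels=4, num_actions=6, nheads=10)
k = random.randrange(10)
obs_chw = np.zeros((1, 4, 84, 84), dtype=np.float32)
action = act_single_head(net, obs_chw, k)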



qlearn/atari/train_prior_bootstrapped_agent.py [143:191]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    replay_buffer = ReplayBuffer(args.replay_buffer_size)

    start_time, start_steps = None, None
    steps_per_iter = RunningAvg(0.999)
    iteration_time_est = RunningAvg(0.999)
    obs = env.reset()
    num_iters = 0
    num_episodes = 0
    num_updates = 0
    prev_lives = None
    episode_rewards = [0.0]
    td_errors_list = []
    best_score = None
    k = random.randrange(args.nheads)
    while True:

        num_iters += 1
        # Take action and store transition in the replay buffer.
        if num_iters <= args.learning_starts:
            action = random.randrange(num_actions)
        else:
            # Reshape state to (1, channels, x_dim, y_dim)
            action = agent.act_single_head(np.transpose(np.array(obs)[None], [0, 3, 1, 2]), k)
        new_obs, rew, done, info = env.step(action)
        # Treat losing a life as a terminal transition for learning, while the
        # environment itself keeps running until the game is actually over.
        death = done or (prev_lives is not None and info['ale.lives'] < prev_lives and info['ale.lives'] > 0)
        prev_lives = info['ale.lives']
        # Store the transition with the reward clipped to its sign, as in DQN.
        replay_buffer.add(obs, action, np.sign(rew), new_obs, float(death))
        obs = new_obs
        episode_rewards[-1] += rew

        if done:
            log.add_scalar('reward', episode_rewards[-1], num_iters)
            episode_rewards.append(0.0)
            obs = env.reset()
            num_episodes += 1
            # Resample which head drives behaviour for the next episode, as in bootstrapped DQN.
            k = random.randrange(args.nheads)

        if num_iters > args.learning_starts and num_iters % args.learning_freq == 0:

            obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size)
            # Reshape state to (batch, channels, x_dim, y_dim)
            obses_t = np.transpose(obses_t, [0, 3, 1, 2])
            obses_tp1 = np.transpose(obses_tp1, [0, 3, 1, 2])

            # Perform one gradient step on the sampled batch; the agent returns the TD error.
            td_errors = agent.learn(obses_t, actions, rewards, obses_tp1, dones)
            td_errors_list.append(td_errors.item())
            log.add_scalar('td_error', td_errors.item(), num_iters)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
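
In both files the `agent.learn` call performs the gradient step and returns the quantity logged as `td_error`. The following is a minimal sketch of what such a step could compute for a single head `k`, assuming a DQN-style target network, a Huber loss, and a scalar mean absolute TD error so that the caller's `.item()` is valid. `learn_step`, `target_net`, and `gamma` are illustrative assumptions; the actual implementation may update all heads at once, e.g. with bootstrap masks.

import torch
import torch.nn.functional as F

def learn_step(net, target_net, optimizer, k,
               obses_t, actions, rewards, obses_tp1, dones, gamma=0.99):
    """One Q-learning update for head k on a sampled batch (illustrative sketch)."""
    # Convert the sampled numpy batch (already in NCHW layout) to tensors.
    obses_t = torch.as_tensor(obses_t, dtype=torch.float32)
    obses_tp1 = torch.as_tensor(obses_tp1, dtype=torch.float32)
    actions = torch.as_tensor(actions, dtype=torch.int64)
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    dones = torch.as_tensor(dones, dtype=torch.float32)

    # Q(s_t, a_t) for the actions that were actually taken, from head k.
    q_t = net(obses_t, k).gather(1, actions.unsqueeze(1)).squeeze(1)

    # One-step Q-learning target; transitions flagged terminal do not bootstrap.
    with torch.no_grad():
        q_tp1 = target_net(obses_tp1, k).max(dim=1).values
        target = rewards + gamma * (1.0 - dones) * q_tp1

    loss = F.smooth_l1_loss(q_t, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Return a scalar tensor so the caller can log it with .item().
    return (q_t - target).abs().mean().detach()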



