in src/utils.py [0:0]
import logging
import os
import traceback

import numpy as np
import torch
import torch.multiprocessing as mp

# Repo-local dependencies assumed to be defined elsewhere in this project:
# log, prof, Buffers, create_env, FrameStack, Environment.


def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers,
        episode_state_count_dict: dict, train_state_count_dict: dict,
        initial_agent_state_buffers, flags):
    """Actor process: steps its own environment copy and fills rollout buffers."""
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings()

        gym_env = create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)
        env = Environment(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
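        # This initial forward pass only produces a first `agent_output` so
        # that time step 0 of a rollout slot can be written before the
        # environment is stepped; the returned state is discarded.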

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Use `j` here: reusing `i` would shadow the actor index.
            for j, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][j][...] = tensor

            # Update the episodic state counts.
            episode_state_key = tuple(env_output['frame'].view(-1).tolist())
            if episode_state_key in episode_state_count_dict:
                episode_state_count_dict[episode_state_key] += 1
            else:
                episode_state_count_dict[episode_state_key] = 1
            buffers['episode_state_count'][index][0, ...] = \
                torch.tensor(1 / np.sqrt(episode_state_count_dict[episode_state_key]))
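            # 1 / sqrt(N(s)) is the standard count-based exploration bonus;
            # precomputing it per step lets the learner consume it directly
            # (the learner-side use is an assumption, not shown in this file).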

            # Reset the episodic state counts when the episode is over.
            if env_output['done'][0][0]:
                episode_state_count_dict = dict()

            # Update the training state counts when running count-based exploration.
            if flags.model == 'count':
                train_state_key = tuple(env_output['frame'].view(-1).tolist())
                if train_state_key in train_state_count_dict:
                    train_state_count_dict[train_state_key] += 1
                else:
                    train_state_count_dict[train_state_key] = 1
                buffers['train_state_count'][index][0, ...] = \
                    torch.tensor(1 / np.sqrt(train_state_count_dict[train_state_key]))
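            # Unlike the episodic counts above, these training counts are
            # never reset, so they track visitation over the whole run.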

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)
                timings.time('model')

                env_output = env.step(agent_output['action'])
                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                # Update the episodic state counts.
                episode_state_key = tuple(env_output['frame'].view(-1).tolist())
                if episode_state_key in episode_state_count_dict:
                    episode_state_count_dict[episode_state_key] += 1
                else:
                    episode_state_count_dict[episode_state_key] = 1
                buffers['episode_state_count'][index][t + 1, ...] = \
                    torch.tensor(1 / np.sqrt(episode_state_count_dict[episode_state_key]))

                # Reset the episodic state counts when the episode is over.
                if env_output['done'][0][0]:
                    episode_state_count_dict = dict()

                # Update the training state counts when running count-based exploration.
                if flags.model == 'count':
                    train_state_key = tuple(env_output['frame'].view(-1).tolist())
                    if train_state_key in train_state_count_dict:
                        train_state_count_dict[train_state_key] += 1
                    else:
                        train_state_count_dict[train_state_key] = 1
                    buffers['train_state_count'][index][t + 1, ...] = \
                        torch.tensor(1 / np.sqrt(train_state_count_dict[train_state_key]))

                timings.time('write')

            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently; the main process handles shutdown.
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
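

# Usage sketch (illustrative, not part of the original file): how an actor
# like `act` is typically launched in the monobeast-style free/full queue
# pattern. `flags.num_actors` and `flags.num_buffers` are assumed flag names,
# and `buffers` is assumed to already be allocated in shared memory.
def start_actors_sketch(flags, model, buffers, initial_agent_state_buffers):
    ctx = mp.get_context('fork')
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    episode_state_count_dict = dict()
    train_state_count_dict = dict()

    actor_processes = []
    for actor_id in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(actor_id, free_queue, full_queue, model, buffers,
                  episode_state_count_dict, train_state_count_dict,
                  initial_agent_state_buffers, flags))
        actor.start()
        actor_processes.append(actor)

    # Every rollout slot starts out free; actors pull an index, fill that
    # slot, and hand it to the learner through `full_queue`.
    for index in range(flags.num_buffers):
        free_queue.put(index)

    return actor_processes, free_queue, full_queue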