in train/learner.py [0:0]
def learner_loop(flags, rank=0, barrier=None, gossip_buffer=None, stop_event=None):
    if flags.num_actors < flags.batch_size:
        logging.warning("Batch size is larger than number of actors.")
    assert (
        flags.batch_size % flags.inference_batch_size == 0
    ), "For now, inference_batch_size must divide batch_size"
    assert (
        flags.num_actors >= flags.inference_batch_size
    ), "Inference batch size must be <= number of actors"
    if flags.logdir:
        log_file_path = os.path.join(flags.logdir, "logs.tsv")
        logging.info(
            "%s logs to %s",
            "Appending" if os.path.exists(log_file_path) else "Writing",
            log_file_path,
        )
    else:
        logging.warning("--logdir not set. Not writing logs to file.")
    if not flags.cc_env_use_state_summary:
        raise NotImplementedError(
            "Setting cc_env_use_state_summary=False is currently not supported. The fetching "
            "of throughput and delay statistics through the states assumes that we are provided "
            "summary statistics, and would need to be updated to work without."
        )
    unroll_queue = queue.Queue(maxsize=1)
    log_queue = queue.Queue()
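    # `unroll_queue` hands completed unrolls over to the learner, while
    # `log_queue` carries end-of-episode stats for logging. With maxsize=1,
    # at most one batch of completed unrolls can be pending at a time.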
    # Inference model.
    model = make_train_model(flags)
    # Dummy (observation, reward, done)
    dummy_env_output = (
        np.zeros(flags.observation_length, dtype=np.float32),
        np.array(0, dtype=np.float32),
        np.array(True, dtype=bool),
    )
    dummy_env_output = nest.map(
        lambda a: torch.from_numpy(np.array(a)), dummy_env_output
    )
    with torch.no_grad():
        dummy_model_output, _ = model(
            last_actions=torch.zeros([1], dtype=torch.int64),
            env_outputs=nest.map(lambda t: t.unsqueeze(0), dummy_env_output),
            core_state=model.initial_state(1),
        )
    dummy_model_output = nest.map(lambda t: t.squeeze(0), dummy_model_output)
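    # This dummy forward pass (batch size 1) is only used to discover the
    # structure and dtypes of the model's outputs, so that the rollout buffers
    # below can be pre-allocated with matching shapes.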
    model = model.to(device=flags.inference_device)
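    # Per-actor bookkeeping buffers: `actions` stores the last action sent to
    # each actor, and `actor_run_ids` the run ID last reported by it (used to
    # detect restarted environments in `inference()` below).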
    # TODO: Decide if we really want that for simple tensors?
    actions = StructuredBuffer(torch.zeros([flags.num_actors], dtype=torch.int64))
    actor_run_ids = StructuredBuffer(torch.zeros([flags.num_actors], dtype=torch.int64))
    actor_infos = StructuredBuffer(
        dict(
            episode_step=torch.zeros([flags.num_actors], dtype=torch.int64),
            episode_return=torch.zeros([flags.num_actors]),
            cwnd_mean=torch.zeros([flags.num_actors]),
            delay_mean=torch.zeros([flags.num_actors]),
            throughput_mean=torch.zeros([flags.num_actors]),
            train_job_id=torch.zeros([flags.num_actors], dtype=torch.int64),
        )
    )
    # Agent states at the beginning of an unroll. Needs to be kept for the learner.
    # A state is a tuple of two tensors of shape [num_actors, hidden_size + 1]:
    #   - Why two tensors? Because the LSTM cell's state contains both its output and its
    #     internal state.
    #   - Why +1? Because the reward is appended to the hidden layer being fed to the LSTM,
    #     and in `SimpleNet` the LSTM cell has the same output size as its input size.
    initial_states = model.initial_state(batch_size=flags.num_actors)
    first_agent_states = StructuredBuffer(initial_states)
    # Current agent states.
    agent_states = StructuredBuffer(copy.deepcopy(initial_states))
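    # Per-actor rollout buffers: `rollouts.append()` below accumulates one
    # timestep per call and returns the actors whose unrolls (of length
    # `unroll_length`) just completed, together with the corresponding data.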
    rollouts = Rollouts(
        dict(
            last_actions=torch.zeros((), dtype=torch.int64),
            env_outputs=dummy_env_output,
            actor_outputs=dummy_model_output,
        ),
        unroll_length=flags.unroll_length,
        num_actors=flags.num_actors,
    )
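    # RPC server through which the actor environments request actions. Incoming
    # calls are grouped into batches of `inference_batch_size` before being
    # dispatched to `inference()` below.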
    server = torchbeast.Server(flags.server_address, max_parallel_calls=4)
    def inference(actor_ids, run_ids, env_outputs):
        """
        Compute actions for a subset of actors, based on their observations.

        :param actor_ids: 1D tensor with the indices of actors whose actions
            are requested.
        :param run_ids: 1D tensor of same length as `actor_ids`, with the
            associated run IDs (used to deal with preemption / crashing: a
            fresh env is expected to provide a new run ID to be sure we do
            not accidentally re-use outdated data). Note that this mechanism
            is not actually needed by the current RL congestion control env.
        :param env_outputs: Tuple of observation (N x obs_dim), reward (N) and
            done (N) tensors, with N the size of `actor_ids`.
        """
        torch.set_grad_enabled(False)
        previous_run_ids = actor_run_ids.get(actor_ids)
        reset_indices = previous_run_ids != run_ids
        actor_run_ids.set(actor_ids, run_ids)
        actors_needing_reset = actor_ids[reset_indices]
        # Update new/restarted actors.
        # NB: this never happens because `runId` is always set to 0 in
        # `CongestionControlRPCEnv::makeCallRequest()`. This is working as
        # intended (calling `reset()` is not needed here).
        if actors_needing_reset.numel():
            logging.info("Actor ids needing reset: %s", actors_needing_reset.tolist())
            actor_infos.clear(actors_needing_reset)
            rollouts.reset(actors_needing_reset)
            actions.clear(actors_needing_reset)
            initial_agent_states = model.initial_state(actors_needing_reset.numel())
            first_agent_states.set(actors_needing_reset, initial_agent_states)
            agent_states.set(actors_needing_reset, initial_agent_states)
        obs, reward, done = env_outputs
        # Update logging stats at end of episode.
        done_ids = actor_ids[done]
        if done_ids.numel():
            # Do not log stats of zero-length episodes (typically these should
            # only happen on the very first episode, due to the `done` flag
            # being true without having any episode before).
            valid_done_ids = done_ids[actor_infos.get(done_ids)["episode_step"] > 0]
            log_queue.put((valid_done_ids, actor_infos.get(valid_done_ids)))
            # NB: `actor_infos.get()` returned a copy of the data, so it is ok
            # to clear it now.
            actor_infos.clear(done_ids)
            # Clear reward for agents that are done: it is meaningless as it is obtained with
            # the first observation, before the agent got a chance to take any action.
            reward[done] = 0.0
            # We only update the `train_job_id` field once (when an episode starts, which
            # is when `done` is True). This is because it remains the same throughout the
            # whole episode.
            extra_info = {"train_job_id": (obs[:, -1] * done).long()}
        else:
            extra_info = {}
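        # Accumulate per-step stats for each actor (`StructuredBuffer.add` is
        # assumed to accumulate element-wise, so `episode_step` counts steps and
        # `episode_return` sums rewards). The means are rescaled back to
        # physical units (packets, ms, MBytes/s) using the normalization flags.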
        actor_infos.add(
            actor_ids,
            dict(
                episode_step=1,
                episode_return=reward,
                cwnd_mean=state.get_mean(obs, state.Field.CWND, dim=1)
                * flags.cc_env_norm_bytes
                / UDP_SEND_PACKET_LEN,
                delay_mean=state.get_mean(obs, state.Field.DELAY, dim=1)
                * flags.cc_env_norm_ms,
                throughput_mean=state.get_mean(obs, state.Field.THROUGHPUT, dim=1)
                * flags.cc_env_norm_bytes
                / 1024 ** 2,  # convert to Mbytes/s
                **extra_info,
            ),
        )
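        # Run the model on the batched observations, feeding back each actor's
        # previous action and recurrent state, then move the results back to CPU.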
        last_actions = actions.get(actor_ids)
        prev_agent_states = agent_states.get(actor_ids)
        actor_outputs, new_agent_states = model(
            *nest.map(
                lambda t: t.to(flags.inference_device),
                (last_actions, env_outputs, prev_agent_states),
            )
        )
        actor_outputs, new_agent_states = nest.map(
            lambda t: t.cpu(), (actor_outputs, new_agent_states)
        )
        timestep = dict(
            last_actions=last_actions,
            env_outputs=env_outputs,
            actor_outputs=actor_outputs,
        )
        completed_ids, unrolls = rollouts.append(actor_ids, timestep)
        if completed_ids.numel():
            try:
                unroll_queue.put(
                    (completed_ids, unrolls, first_agent_states.get(completed_ids)),
                    timeout=5.0,
                )
            except queue.Closed:
                if server.running():
                    raise
            # Importantly, `first_agent_states` must contain the states *before* processing
            # the current batch of data, because in `rollouts` we always start from the last
            # item of the previous rollout (and we need the state before that one). This is
            # why this line must happen before the update to `agent_states` below.
            first_agent_states.set(completed_ids, agent_states.get(completed_ids))
        agent_states.set(actor_ids, new_agent_states)
        action = actor_outputs["action"]
        actions.set(actor_ids, action)
        return action

    server.bind("inference", inference, batch_size=flags.inference_batch_size)
    server.run()
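    # `server.run()` is assumed to only start the serving threads; the actual
    # training loop (`learn()`) runs in this thread and consumes `unroll_queue`
    # and `log_queue`.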
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        try:
            learn(
                model,
                executor,
                unroll_queue,
                log_queue,
                flags,
                rank,
                barrier,
                gossip_buffer,
                stop_event=stop_event,
            )
        except KeyboardInterrupt:
            print("Stopping ...")
        finally:
            unroll_queue.close()
            server.stop()