in distributed/rpc/batch/reinforce.py [0:0]
def __init__(self, world_size, batch=True):
    """Build the agent's policy, spawn remote observers, and reset buffers.

    Args:
        world_size: Total number of RPC workers. Ranks ``1`` through
            ``world_size - 1`` are looked up by name and each gets a
            remote ``Observer`` instance.
        batch: Forwarded to ``Policy`` and to every ``Observer``; also
            selects the container type used for ``saved_log_probs``.
    """
    self.ob_rrefs = []
    # Created before the observers so the agent can be referenced remotely.
    self.agent_rref = RRef(self)
    self.rewards = {}
    self.policy = Policy(batch).cuda()
    self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
    self.running_reward = 0

    # Instantiate one Observer on each non-zero rank and give it an empty
    # reward list keyed by that worker's id.
    for rank in range(1, world_size):
        worker = rpc.get_worker_info(OBSERVER_NAME.format(rank))
        observer_rref = remote(worker, Observer, args=(batch,))
        self.ob_rrefs.append(observer_rref)
        self.rewards[worker.id] = []

    num_observers = len(self.ob_rrefs)

    # One (1, 4) state slot per observer.
    self.states = torch.zeros(num_observers, 1, 4)
    self.batch = batch

    # Batched mode: a flat list of tensors, each holding the probs from all
    # observers for a single step.
    # Non-batched mode: a dict mapping a 0-based observer index to that
    # observer's own list of probs.
    # NOTE(review): these keys are 0..n-1 while self.rewards is keyed by
    # worker id -- confirm against the callers that the two are not
    # expected to match.
    if self.batch:
        self.saved_log_probs = []
    else:
        self.saved_log_probs = {idx: [] for idx in range(num_observers)}

    self.future_actions = torch.futures.Future()
    self.lock = threading.Lock()
    # Starts at the observer count; presumably counted down as states
    # arrive -- verify in the action-selection code.
    self.pending_states = num_observers