in maddpg/trainer/maddpg.py [0:0]
def update(self, agents, t):
    if len(self.replay_buffer) < self.max_replay_buffer_len:  # replay buffer is not large enough
        return
    if t % 100 != 0:  # only update every 100 steps
        return
    self.replay_sample_index = self.replay_buffer.make_index(self.args.batch_size)
    # collect replay samples from all agents, using the same indices so the joint transitions stay aligned
    obs_n = []
    obs_next_n = []
    act_n = []
    index = self.replay_sample_index
    for i in range(self.n):
        obs, act, rew, obs_next, done = agents[i].replay_buffer.sample_index(index)
        obs_n.append(obs)
        obs_next_n.append(obs_next)
        act_n.append(act)
    obs, act, rew, obs_next, done = self.replay_buffer.sample_index(index)
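    # rew and done above are this agent's own reward and terminal signals; obs_n and act_n
    # hold the joint observations and actions that feed the centralized critic below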
    # train q network
    num_sample = 1
    target_q = 0.0
    for _ in range(num_sample):
        # centralized TD target: y = r + gamma * (1 - done) * Q_target(o'_1..o'_N, a'_1..a'_N)
        target_act_next_n = [agents[i].p_debug['target_act'](obs_next_n[i]) for i in range(self.n)]
        target_q_next = self.q_debug['target_q_values'](*(obs_next_n + target_act_next_n))
        target_q += rew + self.args.gamma * (1.0 - done) * target_q_next
    target_q /= num_sample
    q_loss = self.q_train(*(obs_n + act_n + [target_q]))
    # train p network (actor); the policy loss is computed through the centralized critic
    p_loss = self.p_train(*(obs_n + act_n))
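    # move the target policy/critic networks toward the online networks (soft update)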
    self.p_update()
    self.q_update()
    return [q_loss, p_loss, np.mean(target_q), np.mean(rew), np.mean(target_q_next), np.std(target_q)]
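
For context, here is a minimal sketch of the calling side, assuming `trainers` is the list of per-agent trainer objects and `train_step` is the global environment step counter (names here are illustrative; the loop shape mirrors the repo's training script):

    # sketch of how update() is typically driven once per environment step
    for agent in trainers:
        agent.preupdate()                          # per-agent bookkeeping (resets the cached sample index in the original repo)
    for agent in trainers:
        loss = agent.update(trainers, train_step)  # no-op until buffers are warm and train_step % 100 == 0

Each call returns None while the gating conditions at the top of update() fail; otherwise it returns the six-element stats list [q_loss, p_loss, mean target_q, mean reward, mean target_q_next, std target_q] built above.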