in qlearn/atari/prior_bootstrapped_agent.py [0:0]
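# Module-level imports assumed by this excerpt (defined at the top of the file):
#   import torch.nn.functional as F
#   from torch.autograd import Variable
#   from torch.nn.utils import clip_grad_norm_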
def learn(self, states, actions, rewards, next_states, terminals):
    # One gradient step over a sampled minibatch, shared by all bootstrap heads.
    self.online_net.train()
    self.target_net.train()

    # Wrap the minibatch as tensors; observations are stored as uint8 frames
    # and rescaled to [0, 1] here.
    states = Variable(self.FloatTensor(states / 255.0))
    actions = Variable(self.LongTensor(actions))
    next_states = Variable(self.FloatTensor(next_states / 255.0))
    rewards = Variable(self.FloatTensor(rewards)).view(-1, 1)
    terminals = Variable(self.FloatTensor(terminals)).view(-1, 1)
    # Compute Q(s_t, a): each network returns one output per bootstrap head, and
    # we later select the column corresponding to the action actually taken.
    # The fixed random prior network is evaluated on both the current and the
    # next states so that every head's value is online/target output + beta * prior.
    online_prior_outputs = self.prior(states)
    online_outputs = self.online_net(states)
    target_prior_outputs = self.prior(next_states)
    target_outputs = self.target_net(next_states)
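    # In symbols, head k is trained so that
    #     Q_k(s, a) = f_k(s, a) + beta * p_k(s, a)
    # regresses toward the bootstrapped target
    #     y_k = r + (1 - terminal) * gamma * max_a' [ f_k^-(s', a') + beta * p_k(s', a') ],
    # where p_k is the fixed random prior head, f_k the online head, f_k^- the
    # target-network head, and gamma = self.discount.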
    # Accumulate the Huber (smooth L1) TD loss over all bootstrap heads.
    loss = 0
    for k in range(self.nheads):
        # Online value for head k: trainable head output plus the scaled prior.
        # The prior output is detached so no gradients flow into the prior network.
        online_prior_output_ = online_prior_outputs[k].detach()
        online_output_ = online_outputs[k]
        online_value = self.beta * online_prior_output_ + online_output_
        state_action_values = online_value.gather(1, actions.view(-1, 1))

        # Bootstrapped target for head k: greedy value in the next state under
        # the target network (plus prior), zeroed out for terminal transitions.
        target_prior_output_ = target_prior_outputs[k]
        target_output_ = target_outputs[k]
        target_value = self.beta * target_prior_output_ + target_output_
        next_state_values = target_value.max(1)[0].view(-1, 1)
        target_state_action_values = rewards + (1 - terminals) * self.discount * next_state_values

        # The target is detached so gradients only flow through the online estimate.
        loss += F.smooth_l1_loss(state_action_values, target_state_action_values.detach())
    # Optimize the model: backpropagate the summed loss and clip the gradient
    # norm to 10 (an alternative is to clamp each parameter's gradient
    # elementwise to [-1, 1]).
    self.optimiser.zero_grad()
    loss.backward()
    clip_grad_norm_(self.online_net.parameters(), 10)
    self.optimiser.step()
    return loss
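
# Usage sketch (illustrative only): how learn() might be driven from a training
# loop. Names such as `agent`, `replay_buffer`, and `batch_size` are assumptions
# for the example, not identifiers taken from this repository.
#
#   states, actions, rewards, next_states, terminals = replay_buffer.sample(batch_size)
#   loss = agent.learn(states, actions, rewards, next_states, terminals)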