def learn()

in qlearn/atari/prior_bootstrapped_agent.py


    def learn(self, states, actions, rewards, next_states, terminals):
        self.online_net.train()
        self.target_net.train()
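        # Convert the sampled transition batch to tensors; pixel observations
        # are rescaled from [0, 255] to [0, 1]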
        states = Variable(self.FloatTensor(states / 255.0))
        actions = Variable(self.LongTensor(actions))
        next_states = Variable(self.FloatTensor(next_states / 255.0))
        rewards = Variable(self.FloatTensor(rewards)).view(-1, 1)
        terminals = Variable(self.FloatTensor(terminals)).view(-1, 1)

        # Forward passes: the fixed prior network and the trainable online
        # network evaluate the current states, while the prior and the target
        # network evaluate the next states; each returns one output per head
        online_prior_outputs = self.prior(states)
        online_outputs = self.online_net(states)

        target_prior_outputs = self.prior(next_states)
        target_outputs = self.target_net(next_states)

        # Accumulate the TD loss over all bootstrap heads; each head's value is
        # the online network's output plus the fixed prior scaled by beta
        loss = 0
        for k in range(self.nheads):
            # Q(s_t, a): combine head k's online output with its (detached)
            # prior, then select the columns of the actions actually taken
            online_prior_output_ = online_prior_outputs[k].detach()
            online_output_ = online_outputs[k]
            online_value = self.beta * online_prior_output_ + online_output_
            state_action_values = online_value.gather(1, actions.view(-1, 1))

            # max_a Q_target(s_{t+1}, a) from the target network plus the prior
            target_prior_output_ = target_prior_outputs[k]
            target_output_ = target_outputs[k]
            target_value = self.beta * target_prior_output_ + target_output_
            next_state_values = target_value.max(1)[0].view(-1, 1)

            # One-step TD target; terminal transitions do not bootstrap
            target_state_action_values = rewards + (1 - terminals) * self.discount * next_state_values

            # Huber loss; the target is detached so gradients only flow
            # through the online network
            loss += F.smooth_l1_loss(state_action_values, target_state_action_values.detach())

        # Optimise the online network; gradients are clipped by global norm to
        # stabilise training
        self.optimiser.zero_grad()
        loss.backward()
        clip_grad_norm_(self.online_net.parameters(), 10)
        self.optimiser.step()

        return loss
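
The per-head value used above follows the randomised-prior pattern: each head's Q-values are the trainable network's output plus a fixed, untrained prior network's output scaled by beta, with the prior detached so it never receives gradients. The sketch below is a minimal, self-contained illustration of that combination for a single head; the layer sizes, beta value, and variable names are illustrative assumptions, not the agent's actual architecture.

    # Minimal sketch of the prior/online value combination, assuming a single
    # head with small linear networks and an arbitrary beta.
    import torch
    import torch.nn as nn

    state_dim, n_actions, beta = 8, 4, 3.0

    # Fixed prior network: randomly initialised and never trained
    prior_net = nn.Linear(state_dim, n_actions)
    for p in prior_net.parameters():
        p.requires_grad_(False)

    # Trainable online network for the same head
    online_net = nn.Linear(state_dim, n_actions)

    states = torch.rand(32, state_dim)  # dummy batch of states

    # Per-head value: Q_k(s) = online_k(s) + beta * prior_k(s); the prior is
    # detached so gradients only reach the online network
    q_values = online_net(states) + beta * prior_net(states).detach()
    greedy_actions = q_values.argmax(dim=1)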