def _train_dqn_policy()

in activemri/baselines/ddqn.py


    def _train_dqn_policy(self):
        """ Trains the DQN policy. """
        self.logger.info(
            f"Starting training at step {self.steps}/{self.options.num_train_steps}. "
            f"Best score so far is {self.best_test_score}."
        )

        steps_epsilon = self.steps
        while self.steps < self.options.num_train_steps:
            self.logger.info("Episode {}".format(self.episode + 1))

            # Evaluate the current policy
            if self.options.dqn_test_episode_freq and (
                self.episode % self.options.dqn_test_episode_freq == 0
            ):
                test_scores, _ = evaluation.evaluate(
                    self.env,
                    self.policy,
                    self.options.num_test_episodes,
                    self.options.seed,
                    "val",
                )
                self.env.set_training()
                auc_score = test_scores[self.options.reward_metric].sum(axis=1).mean()
                if "mse" in self.options.reward_metric:
                    auc_score *= -1
                if auc_score > self.best_test_score:
                    policy_path = os.path.join(
                        self.options.checkpoints_dir, "policy_best.pt"
                    )
                    self.save(policy_path)
                    self.best_test_score = auc_score
                    self.logger.info(
                        f"Saved DQN model with score {self.best_test_score} to {policy_path}."
                    )

            # Save model periodically
            if self.episode % self.options.freq_dqn_checkpoint_save == 0:
                self.checkpoint(save_memory=False)

            # Run an episode and update model
            obs, meta = self.env.reset()
            msg = ", ".join(
                [
                    f"({meta['fname'][i]}, {meta['slice_id'][i]})"
                    for i in range(len(meta["slice_id"]))
                ]
            )
            self.logger.info(f"Episode started with images {msg}.")
            all_done = False
            total_reward = 0
            auc_score = 0
            while not all_done:
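                # Choose an action epsilon-greedily; the exploration rate decays with the total number of steps taken.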
                epsilon = _get_epsilon(steps_epsilon, self.options)
                action = self.policy.get_action(obs, eps_threshold=epsilon)
                next_obs, reward, done, meta = self.env.step(action)
                auc_score += meta["current_score"][self.options.reward_metric]
                all_done = all(done)
                self.steps += 1
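                # Encode the observation dicts into tensors and add one transition per image in the batch to the policy's experience buffer.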
                obs_tensor = _encode_obs_dict(obs)
                next_obs_tensor = _encode_obs_dict(next_obs)
                batch_size = len(obs_tensor)
                for i in range(batch_size):
                    self.policy.add_experience(
                        obs_tensor[i], action[i], next_obs_tensor[i], reward[i], done[i]
                    )

                update_results = self.policy.update_parameters(self.target_net)
                torch.cuda.empty_cache()
                if self.steps % self.options.target_net_update_freq == 0:
                    self.logger.info("Updating target network.")
                    self.target_net.load_state_dict(self.policy.state_dict())
                steps_epsilon += 1

                # Add per-step tensorboard logs
                if self.steps % 250 == 0:
                    self.logger.debug("Writing to tensorboard.")
                    self.writer.add_scalar("epsilon", epsilon, self.steps)
                    if update_results is not None:
                        self.writer.add_scalar(
                            "loss", update_results["loss"], self.steps
                        )
                        self.writer.add_scalar(
                            "grad_norm", update_results["grad_norm"], self.steps
                        )
                        self.writer.add_scalar(
                            "mean_q_value", update_results["q_values_mean"], self.steps
                        )
                        self.writer.add_scalar(
                            "std_q_value", update_results["q_values_std"], self.steps
                        )

                total_reward += reward
                obs = next_obs

            # Add per-episode tensorboard logs
            total_reward = total_reward.mean().item()
            auc_score = auc_score.mean().item()
            self.reward_images_in_window[self.episode % self.window_size] = total_reward
            self.current_score_auc_window[self.episode % self.window_size] = auc_score
            self.writer.add_scalar("episode_reward", total_reward, self.episode)
            self.writer.add_scalar(
                "average_reward_images_in_window",
                np.sum(self.reward_images_in_window)
                / min(self.episode + 1, self.window_size),
                self.episode,
            )
            self.writer.add_scalar(
                "average_auc_score_in_window",
                np.sum(self.current_score_auc_window)
                / min(self.episode + 1, self.window_size),
                self.episode,
            )

            self.episode += 1

        self.checkpoint()

        # Write the DONE file with the best test score
        with _get_folder_lock(self.folder_lock_path):
            with open(
                DDQNTrainer.get_done_filename(self.options.checkpoints_dir), "w"
            ) as f:
                f.write(str(self.best_test_score))

        return self.best_test_score
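
The epsilon-greedy exploration rate above comes from _get_epsilon(steps_epsilon, self.options), which is not shown in this excerpt. Below is a minimal sketch of such a decay schedule, assuming hypothetical options named epsilon_start, epsilon_end, and epsilon_decay; the actual option names and decay shape in ddqn.py may differ.

    import math

    def _get_epsilon(steps_done, options):
        # Anneal the exploration rate from epsilon_start toward epsilon_end
        # as the number of environment steps grows.
        return options.epsilon_end + (options.epsilon_start - options.epsilon_end) * math.exp(
            -1.0 * steps_done / options.epsilon_decay
        )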
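
The learning step itself happens inside self.policy.update_parameters(self.target_net), whose internals are also not shown. For reference, the double-DQN technique the trainer is named after selects the next action with the online network and evaluates it with the target network. The sketch below is illustrative only, not the repository's implementation, and assumes policy_net and target_net map observation tensors to per-action Q-values.

    import torch

    def double_dqn_targets(policy_net, target_net, next_obs, rewards, dones, gamma=0.99):
        # Double DQN: the online network picks the argmax action, while the
        # target network supplies its value, reducing overestimation bias.
        with torch.no_grad():
            next_actions = policy_net(next_obs).argmax(dim=1, keepdim=True)
            next_values = target_net(next_obs).gather(1, next_actions).squeeze(1)
            return rewards + gamma * next_values * (1.0 - dones.float())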
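
After training completes, the DONE file written above contains the best test score as plain text. A caller might read it back roughly as follows; this is a hypothetical usage example that reuses the same options.checkpoints_dir.

    import os

    done_file = DDQNTrainer.get_done_filename(options.checkpoints_dir)
    if os.path.exists(done_file):
        with open(done_file, "r") as f:
            best_test_score = float(f.read())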