in activemri/baselines/ddqn.py [0:0]
    def _train_dqn_policy(self):
        """Trains the DQN policy and returns the best validation score seen."""
        self.logger.info(
            f"Starting training at step {self.steps}/{self.options.num_train_steps}. "
            f"Best score so far is {self.best_test_score}."
        )
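        # Step counter used by the epsilon-greedy exploration schedule.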
        steps_epsilon = self.steps
        while self.steps < self.options.num_train_steps:
self.logger.info("Episode {}".format(self.episode + 1))

            # Evaluate the current policy
            if self.options.dqn_test_episode_freq and (
                self.episode % self.options.dqn_test_episode_freq == 0
            ):
                test_scores, _ = evaluation.evaluate(
                    self.env,
                    self.policy,
                    self.options.num_test_episodes,
                    self.options.seed,
                    "val",
                )
                self.env.set_training()
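                # Aggregate the test metric into an AUC-style score: sum over
                # acquisition steps, then average over test episodes.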
                auc_score = test_scores[self.options.reward_metric].sum(axis=1).mean()
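                # Error metrics such as MSE are better when lower, so flip the
                # sign to keep "higher is better" for the comparison below.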
if "mse" in self.options.reward_metric:
auc_score *= -1
if auc_score > self.best_test_score:
policy_path = os.path.join(
self.options.checkpoints_dir, "policy_best.pt"
)
self.save(policy_path)
self.best_test_score = auc_score
self.logger.info(
f"Saved DQN model with score {self.best_test_score} to {policy_path}."
)

            # Save model periodically
            if self.episode % self.options.freq_dqn_checkpoint_save == 0:
                self.checkpoint(save_memory=False)

            # Run an episode and update model
            obs, meta = self.env.reset()
            msg = ", ".join(
                [
                    f"({meta['fname'][i]}, {meta['slice_id'][i]})"
                    for i in range(len(meta["slice_id"]))
                ]
            )
            self.logger.info(f"Episode started with images {msg}.")
            all_done = False
            total_reward = 0
            auc_score = 0
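            # Step the batched environment until every image in the batch is done.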
            while not all_done:
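                # Epsilon-greedy action selection; _get_epsilon computes the
                # current exploration rate from the step count.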
                epsilon = _get_epsilon(steps_epsilon, self.options)
                action = self.policy.get_action(obs, eps_threshold=epsilon)
                next_obs, reward, done, meta = self.env.step(action)
                auc_score += meta["current_score"][self.options.reward_metric]
                all_done = all(done)
                self.steps += 1
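                # Encode the observation dicts as tensors and store one
                # transition per image in the policy's replay memory.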
                obs_tensor = _encode_obs_dict(obs)
                next_obs_tensor = _encode_obs_dict(next_obs)
                batch_size = len(obs_tensor)
                for i in range(batch_size):
                    self.policy.add_experience(
                        obs_tensor[i], action[i], next_obs_tensor[i], reward[i], done[i]
                    )
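                # One learning step on the policy network; returns loss/gradient/
                # Q-value statistics for logging, or None if no update was made.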
                update_results = self.policy.update_parameters(self.target_net)
                torch.cuda.empty_cache()
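                # Copy the online network's weights into the target network at a
                # fixed step interval.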
                if self.steps % self.options.target_net_update_freq == 0:
                    self.logger.info("Updating target network.")
                    self.target_net.load_state_dict(self.policy.state_dict())
                steps_epsilon += 1

                # Per-step tensorboard logs
                if self.steps % 250 == 0:
                    self.logger.debug("Writing to tensorboard.")
                    self.writer.add_scalar("epsilon", epsilon, self.steps)
                    if update_results is not None:
                        self.writer.add_scalar(
                            "loss", update_results["loss"], self.steps
                        )
                        self.writer.add_scalar(
                            "grad_norm", update_results["grad_norm"], self.steps
                        )
                        self.writer.add_scalar(
                            "mean_q_value", update_results["q_values_mean"], self.steps
                        )
                        self.writer.add_scalar(
                            "std_q_value", update_results["q_values_std"], self.steps
                        )
                total_reward += reward
                obs = next_obs

            # Per-episode tensorboard logs
            total_reward = total_reward.mean().item()
            auc_score = auc_score.mean().item()
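            # Update the sliding windows used to report averages over the most
            # recent window_size episodes.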
            self.reward_images_in_window[self.episode % self.window_size] = total_reward
            self.current_score_auc_window[self.episode % self.window_size] = auc_score
            self.writer.add_scalar("episode_reward", total_reward, self.episode)
            self.writer.add_scalar(
                "average_reward_images_in_window",
                np.sum(self.reward_images_in_window)
                / min(self.episode + 1, self.window_size),
                self.episode,
            )
            self.writer.add_scalar(
                "average_auc_score_in_window",
                np.sum(self.current_score_auc_window)
                / min(self.episode + 1, self.window_size),
                self.episode,
            )
            self.episode += 1
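
        # Save a final checkpoint once the training step budget is exhausted.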
        self.checkpoint()
        # Write the DONE file with the best test score
        with _get_folder_lock(self.folder_lock_path):
            with open(
                DDQNTrainer.get_done_filename(self.options.checkpoints_dir), "w"
            ) as f:
                f.write(str(self.best_test_score))
        return self.best_test_score