in train/learner.py [0:0]
import logging
import math

import torch


def normalize_rewards(flags, train_job_id, rewards, reward_stats):
"""
Compute the normalized rewards, updating `reward_stats` at the same time.
:param flags: Config flags.
:param train_job_id: Tensor (T, N) of integers providing the training job ID
associated to the corresponding reward in `rewards`.
:param rewards: Tensor (T, N) of observed rewards. This tensor is updated
in-place.
:param reward_stats: The dictionay holding reward stats.
If `reward_normalization_stats_per_job` is False, it has a single key (0),
and otherwise it has one key per train job ID. Each key is associated to
a dictionary with the following key / value pairs:
- Before `1/reward_normalization_coeff` rewards have been seen (for the
corresponding key):
* "initial_rewards": a tensor containing all rewards since so far
- After `1/reward_normalization_coeff` rewards have been seen:
* "offset": applied to rewards to shift them
* "mean": running mean of (shifted) rewards
* "mean_squared": running mean of (shifted) squared rewards
"""
rewards_flat = rewards.flatten()
is_job_id = None
if flags.reward_normalization_stats_per_job:
# Split rewards according to their train job ID.
train_job_id_flat = train_job_id.flatten() # [T * N]
all_job_ids = torch.unique(train_job_id_flat).view(-1, 1) # [M, 1]
if len(all_job_ids) == 1:
# Optimization if there is a single job ID.
rewards_per_job = {all_job_ids[0].item(): rewards_flat}
else:
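            # Broadcasting [T * N] against [M, 1] gives one boolean row per job
            # ID, marking which flattened rewards belong to that job.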
is_job_id = train_job_id_flat == all_job_ids # [M, T * N]
rewards_per_job = {
job_id.item(): rewards_flat[match]
for job_id, match in zip(all_job_ids, is_job_id)
}
else:
rewards_per_job = {0: rewards_flat}
normalized_rewards = []
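    # Normalize each job's rewards with its own statistics: exact mean / std
    # during the warm-up phase, exponential moving averages afterwards.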
for job_id, job_rewards in rewards_per_job.items():
job_stats = reward_stats[job_id]
if "initial_rewards" in job_stats:
# Use regular mean / std instead of moving averages.
initial_rewards = job_stats["initial_rewards"] = torch.cat(
(job_stats["initial_rewards"], job_rewards)
)
mean_rewards = initial_rewards.mean()
std_rewards = max(
math.sqrt(flags.reward_normalization_var_eps), initial_rewards.std()
)
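            # Warm-up length: 1 / reward_normalization_coeff rounded to the
            # nearest integer, i.e. roughly the averaging horizon of the moving
            # averages used after the switch below.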
n_init = int(1 / flags.reward_normalization_coeff + 0.5)
if len(initial_rewards) >= n_init:
# Switch to moving averages.
del job_stats["initial_rewards"]
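                # The offset absorbs the warm-up mean, so the running mean of
                # shifted rewards starts at 0 and the running mean of squared
                # shifted rewards starts at the warm-up variance.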
job_stats["mean"] = 0.0
job_stats["offset"] = mean_rewards
job_stats["mean_squared"] = std_rewards ** 2
logging.info(
f"Reward statistics (job_id={job_id}) after observing the first "
f"{len(initial_rewards)} rewards: mean = {mean_rewards}, std = {std_rewards}"
)
else:
# Use moving averages.
mean_rewards = job_stats["mean"]
mean_squared_rewards = job_stats["mean_squared"]
offset = job_stats["offset"]
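            # The offset stays fixed after warm-up; tracking stats on shifted
            # rewards presumably keeps `mean_squared - mean ** 2` numerically
            # well behaved.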
# Shift rewards by offset.
rewards_shifted = job_rewards - offset
shifted_mean = rewards_shifted.mean()
# Update moving averages.
n = len(rewards_shifted)
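            # (1 - coeff) ** n is the weight left on the old estimate after n
            # per-sample EMA updates; the remaining (1 - alpha) weight goes to
            # the batch mean, i.e. rewards within this batch are treated as
            # equally recent.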
alpha = (1 - flags.reward_normalization_coeff) ** n
mean_rewards = job_stats["mean"] = (
alpha * mean_rewards + (1 - alpha) * shifted_mean
)
shifted_squared_mean = (rewards_shifted ** 2).mean()
mean_squared_rewards = job_stats["mean_squared"] = (
alpha * mean_squared_rewards + (1 - alpha) * shifted_squared_mean
)
var_rewards = max(
flags.reward_normalization_var_eps,
mean_squared_rewards - mean_rewards ** 2,
)
# Compute the mean / std used for reward normalization.
mean_rewards = offset + mean_rewards
std_rewards = math.sqrt(var_rewards)
# Apply the reward normalization formula.
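        # The sqrt(1 - discounting ** 2) factor rescales unit-variance rewards
        # so that a discounted sum of i.i.d. rewards keeps roughly unit variance
        # (sum over t of discounting ** (2 * t) = 1 / (1 - discounting ** 2)).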
new_job_rewards = (job_rewards - mean_rewards) * (
math.sqrt(1 - flags.discounting ** 2) / std_rewards
)
assert new_job_rewards.dtype == torch.float32
normalized_rewards.append(new_job_rewards)
# Update the `rewards` input tensor with normalized rewards.
if len(normalized_rewards) == 1:
# Single job ID.
rewards_flat[:] = normalized_rewards[0]
else:
assert is_job_id is not None
for job_normalized_rewards, match in zip(normalized_rewards, is_job_id):
rewards_flat[match] = job_normalized_rewards
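

# Hypothetical usage sketch (illustration only; the real initialization of
# `reward_stats` lives elsewhere in the training code and may differ):
#
#     reward_stats = collections.defaultdict(
#         # Empty tensor on the same device / dtype as the incoming rewards.
#         lambda: {"initial_rewards": torch.zeros(0)}
#     )
#     ...
#     normalize_rewards(flags, train_job_id, rewards, reward_stats)
#     # `rewards` now holds the normalized values (updated in-place).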