def normalize_rewards()

in train/learner.py [0:0]


def normalize_rewards(flags, train_job_id, rewards, reward_stats):
    """
    Compute the normalized rewards, updating `reward_stats` at the same time.

    :param flags: Config flags.
    :param train_job_id: Tensor (T, N) of integers giving the training job ID
        associated with the corresponding reward in `rewards`.
    :param rewards: Tensor (T, N) of observed rewards. This tensor is updated
        in-place.
    :param reward_stats: The dictionary holding reward statistics.
        If `reward_normalization_stats_per_job` is False, it has a single key (0);
        otherwise it has one key per train job ID. Each key maps to a dictionary
        with the following key / value pairs:
            - Before `1/reward_normalization_coeff` rewards have been seen (for the
              corresponding key):
                * "initial_rewards": a tensor containing all rewards since so far
            - After `1/reward_normalization_coeff` rewards have been seen:
                * "offset": applied to rewards to shift them
                * "mean": running mean of (shifted) rewards
                * "mean_squared": running mean of (shifted) squared rewards
    """
    rewards_flat = rewards.flatten()
    is_job_id = None
    if flags.reward_normalization_stats_per_job:
        # Split rewards according to their train job ID.
        train_job_id_flat = train_job_id.flatten()  # [T * N]
        all_job_ids = torch.unique(train_job_id_flat).view(-1, 1)  # [M, 1]
        if len(all_job_ids) == 1:
            # Optimization if there is a single job ID.
            rewards_per_job = {all_job_ids[0].item(): rewards_flat}
        else:
            is_job_id = train_job_id_flat == all_job_ids  # [M, T * N]
            rewards_per_job = {
                job_id.item(): rewards_flat[match]
                for job_id, match in zip(all_job_ids, is_job_id)
            }
    else:
        rewards_per_job = {0: rewards_flat}

    normalized_rewards = []
    for job_id, job_rewards in rewards_per_job.items():
        job_stats = reward_stats[job_id]
        if "initial_rewards" in job_stats:
            # Use regular mean / std instead of moving averages.
            initial_rewards = job_stats["initial_rewards"] = torch.cat(
                (job_stats["initial_rewards"], job_rewards)
            )
            mean_rewards = initial_rewards.mean()
            std_rewards = max(
                math.sqrt(flags.reward_normalization_var_eps), initial_rewards.std()
            )
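            # Number of warm-up rewards to accumulate before switching to moving
            # averages (1 / reward_normalization_coeff, rounded to the nearest integer).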
            n_init = int(1 / flags.reward_normalization_coeff + 0.5)
            if len(initial_rewards) >= n_init:
                # Switch to moving averages.
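                # Statistics are stored relative to `offset`: the running mean starts
                # at 0 and the running mean of squares at the current variance.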
                del job_stats["initial_rewards"]
                job_stats["mean"] = 0.0
                job_stats["offset"] = mean_rewards
                job_stats["mean_squared"] = std_rewards ** 2
                logging.info(
                    f"Reward statistics (job_id={job_id}) after observing the first "
                    f"{len(initial_rewards)} rewards: mean = {mean_rewards}, std = {std_rewards}"
                )
        else:
            # Use moving averages.
            mean_rewards = job_stats["mean"]
            mean_squared_rewards = job_stats["mean_squared"]
            offset = job_stats["offset"]

            # Shift rewards by offset.
            rewards_shifted = job_rewards - offset
            shifted_mean = rewards_shifted.mean()

            # Update moving averages.
            n = len(rewards_shifted)
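            # Treat the batch as n consecutive EMA updates: the old statistics keep
            # weight (1 - coeff)^n and the new batch collectively gets weight 1 - alpha.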
            alpha = (1 - flags.reward_normalization_coeff) ** n
            mean_rewards = job_stats["mean"] = (
                alpha * mean_rewards + (1 - alpha) * shifted_mean
            )
            shifted_squared_mean = (rewards_shifted ** 2).mean()
            mean_squared_rewards = job_stats["mean_squared"] = (
                alpha * mean_squared_rewards + (1 - alpha) * shifted_squared_mean
            )
            var_rewards = max(
                flags.reward_normalization_var_eps,
                mean_squared_rewards - mean_rewards ** 2,
            )

            # Compute the mean / std used for reward normalization.
            mean_rewards = offset + mean_rewards
            std_rewards = math.sqrt(var_rewards)

        # Apply the reward normalization formula.
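        # The sqrt(1 - discounting^2) factor presumably compensates for discounted
        # summation: for i.i.d. unit-variance rewards, Var(sum_k gamma^k * r_k) is
        # 1 / (1 - gamma^2).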
        new_job_rewards = (job_rewards - mean_rewards) * (
            math.sqrt(1 - flags.discounting ** 2) / std_rewards
        )
        assert new_job_rewards.dtype == torch.float32
        normalized_rewards.append(new_job_rewards)

    # Update the `rewards` input tensor with normalized rewards.
    if len(normalized_rewards) == 1:
        # Single job ID.
        rewards_flat[:] = normalized_rewards[0]
    else:
        assert is_job_id is not None
        for job_normalized_rewards, match in zip(normalized_rewards, is_job_id):
            rewards_flat[match] = job_normalized_rewards
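
A minimal usage sketch follows. It assumes each `reward_stats` entry starts out as
{"initial_rewards": <empty tensor>} (the convention the `"initial_rewards" in job_stats`
check above relies on), and it stands in a `types.SimpleNamespace` for the real config
flags; the flag values and tensor shapes are illustrative only.

import collections
import types

import torch

flags = types.SimpleNamespace(
    reward_normalization_stats_per_job=True,
    reward_normalization_coeff=1e-4,    # ~10,000 warm-up rewards per job
    reward_normalization_var_eps=1e-6,  # lower bound on the reward variance
    discounting=0.99,
)

T, N = 32, 8  # unroll length and batch size
rewards = torch.randn(T, N)
train_job_id = torch.randint(0, 3, (T, N))

# Every job ID starts in the "initial rewards" regime with an empty buffer.
reward_stats = collections.defaultdict(lambda: {"initial_rewards": torch.empty(0)})

# `normalize_rewards` as defined above (train/learner.py).
normalize_rewards(flags, train_job_id, rewards, reward_stats)
# `rewards` now holds the normalized values (updated in-place).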