lib/scaled_mse_head.py (29 lines of code) (raw):

from typing import Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from lib.action_head import fan_in_linear
from lib.normalize_ewma import NormalizeEwma


class ScaledMSEHead(nn.Module):
    """Linear output head whose loss is computed in a normalized target space.

    Predictions produced by ``forward`` are treated as already-normalized
    values. Raw targets are pushed through ``self.normalizer`` (a
    ``NormalizeEwma`` instance) before being compared with an MSE loss. The
    stated design intent is that normalized targets are approximately N(0, 1).
    ``denormalize`` maps normalized values back to the original target scale.

    Args:
        input_size: Number of input features for the linear layer.
        output_size: Number of output features. Also passed as the first
            positional argument to ``NormalizeEwma``.
        norm_type: Kept on ``self.norm_type`` only. This class never reads it;
            a ``NormalizeEwma`` normalizer is built whatever its value.
        norm_kwargs: Extra keyword arguments forwarded to ``NormalizeEwma``.
            ``None`` means "no extra arguments".
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        norm_type: Optional[str] = "ewma",
        norm_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # NOTE(review): norm_type is recorded but never consulted; the
        # normalizer constructed below is unconditionally NormalizeEwma.
        self.norm_type = norm_type

        # Register the linear layer before the normalizer so submodule /
        # parameter / state_dict ordering stays stable.
        self.linear = nn.Linear(input_size, output_size)

        if norm_kwargs is None:
            extra_norm_args = {}
        else:
            extra_norm_args = norm_kwargs
        self.normalizer = NormalizeEwma(output_size, **extra_norm_args)

    def reset_parameters(self):
        """Re-initialize the linear layer's weights and reset the normalizer.

        ``init.orthogonal_`` runs first and ``fan_in_linear`` second.
        Presumably the latter rescales the orthogonal weights (and possibly
        the bias). TODO: confirm against ``lib.action_head``.
        """
        layer = self.linear
        init.orthogonal_(layer.weight)
        fan_in_linear(layer)
        self.normalizer.reset_parameters()

    def forward(self, input_data):
        """Apply the linear layer. The output is interpreted as normalized."""
        return self.linear(input_data)

    def loss(self, prediction, target):
        """Return the mean-squared error between prediction and target.

        ``prediction`` is expected to be in the normalized space, as produced
        by ``forward``. ``target`` is in the original (denormalized) space and
        is normalized here first, so the loss is measured in normalized units.

        NOTE(review): whether calling the normalizer also updates its running
        statistics is decided inside ``NormalizeEwma`` (not visible here).
        Confirm before calling this outside of training.
        """
        normalized_target = self.normalizer(target)
        return F.mse_loss(prediction, normalized_target, reduction="mean")

    def denormalize(self, input_data):
        """Map a value from the normalized space back to the original scale."""
        return self.normalizer.denormalize(input_data)

    def normalize(self, input_data):
        """Map a value from the original scale into the normalized space.

        This is the same transform ``loss`` applies to its targets.
        """
        return self.normalizer(input_data)