def init_rbf_state_enc()

in multiple_futures_prediction/model_ngsim.py [0:0]


  def init_rbf_state_enc(self, in_dim: int) -> None:
    """Initialize the dynamic attentional RBF encoder.

    Sets the encoder's size hyper-parameters on ``self`` and creates four
    trainable components, in this order (the order determines which random
    numbers each component's initialization consumes for a given seed):

    * ``sec_key_net``: MLP, (in_dim + extra_pos_dim) -> sec_key_dim.
    * ``slot_keys``: ParameterList of ``num_slots`` float32 tensors of shape
      (sec_key_dim, 1), drawn with ``torch.randn`` and scaled by
      ``slot_key_scale``.
    * ``sec_hist_net``: MLP, in_dim * num_slots -> st_enc_hist_size.
    * ``sec_pos_net``: MLP, sec_in_pos_dim * num_slots -> st_enc_pos_size.

    ``self.st_enc_hist_size`` and ``self.st_enc_pos_size`` must already be set
    when this is called. The method may be called again on an initialized
    module to rebuild the encoder.

    Args:
      in_dim: input dim of the observation.
    """
    self.sec_in_dim = in_dim
    self.extra_pos_dim = 2

    self.sec_in_pos_dim     = 2
    self.sec_key_dim        = 8
    self.sec_key_hidden_dim = 32

    self.sec_hidden_dim     = 32
    self.scale = 1.0  # not read in this method
    self.slot_key_scale = 1.0
    self.num_slots = 8

    # Network for computing the 'key'.
    self.sec_key_net = torch.nn.Sequential( #type: ignore
        torch.nn.Linear(self.sec_in_dim + self.extra_pos_dim, self.sec_key_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_key_hidden_dim, self.sec_key_dim),
    )

    # One learnable key per slot. The ParameterList is built in one step:
    # first assigning a plain list to ``self.slot_keys`` would raise TypeError
    # on a re-call, because ``slot_keys`` is then a registered submodule.
    # Kept after sec_key_net so the random-draw order is unchanged.
    self.slot_keys = torch.nn.ParameterList([  # type: ignore
        torch.nn.Parameter(
            self.slot_key_scale * torch.randn(self.sec_key_dim, 1, dtype=torch.float32))
        for _ in range(self.num_slots)
    ])

    # Network for encoding a scene-level contextual feature.
    self.sec_hist_net = torch.nn.Sequential( #type: ignore
        torch.nn.Linear(self.sec_in_dim * self.num_slots, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.st_enc_hist_size),
    )

    # Encodes others' positions into a feature; the input should be
    # normalized relative to ref_pos.
    self.sec_pos_net = torch.nn.Sequential( #type: ignore
        torch.nn.Linear(self.sec_in_pos_dim * self.num_slots, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.st_enc_pos_size),
    )