in multiple_futures_prediction/model_ngsim.py [0:0]
def init_rbf_state_enc(self, in_dim: int) -> None:
    """Set up the dynamic attentional RBF state encoder.

    Args:
        in_dim: input dimension of the observation.

    Side effects: defines encoder hyperparameters on ``self`` plus three
    sub-networks (``sec_key_net``, ``sec_hist_net``, ``sec_pos_net``) and the
    learnable per-slot keys (``slot_keys``). Reads ``self.st_enc_hist_size``
    and ``self.st_enc_pos_size``, which must already be set.
    """
    # Encoder hyperparameters.
    self.sec_in_dim = in_dim
    self.extra_pos_dim = 2
    self.sec_in_pos_dim = 2
    self.sec_key_dim = 8
    self.sec_key_hidden_dim = 32
    self.sec_hidden_dim = 32
    self.scale = 1.0
    self.slot_key_scale = 1.0
    self.num_slots = 8

    # Network computing the 'key' from an observation concatenated with
    # extra position features.
    self.sec_key_net = torch.nn.Sequential(  # type: ignore
        torch.nn.Linear(self.sec_in_dim + self.extra_pos_dim,
                        self.sec_key_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_key_hidden_dim, self.sec_key_dim),
    )

    # One learnable key per slot, initialized from a scaled standard normal.
    self.slot_keys = torch.nn.ParameterList(  # type: ignore
        torch.nn.Parameter(
            self.slot_key_scale
            * torch.randn(self.sec_key_dim, 1, dtype=torch.float32)
        )
        for _ in range(self.num_slots)
    )

    # Network encoding a scene-level contextual feature from all slots.
    self.sec_hist_net = torch.nn.Sequential(  # type: ignore
        torch.nn.Linear(self.sec_in_dim * self.num_slots, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.st_enc_hist_size),
    )

    # Network encoding others' positions into a feature; inputs should be
    # normalized to ref_pos.
    self.sec_pos_net = torch.nn.Sequential(  # type: ignore
        torch.nn.Linear(self.sec_in_pos_dim * self.num_slots,
                        self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(self.sec_hidden_dim, self.st_enc_pos_size),
    )