in multiple_futures_prediction/dataset_ngsim.py [0:0]
def __init__(self, mat_file:str, t_h:int=30, t_f:int=50, d_s:int = 2,
             enc_size:int=64, use_gru:bool=False, self_norm:bool=False,
             data_aug:bool=False, use_context:bool=False, nbr_search_depth:int= 3,
             ds_seed:int=1234) -> None:
    """Load NGSIM trajectories from a .mat file and build the sample index.

    Args:
      mat_file: path to a MATLAB file containing ``traj`` and ``tracks``
        variables.
      t_h: length of track history.
      t_f: length of predicted trajectory.
      d_s: down sampling rate of all sequences; also sets ``dt = 0.1 * d_s``.
      enc_size: size of encoder LSTM.
      use_gru: if True ``enc_fac`` is 2, otherwise 1.
      self_norm: stored as ``self.self_norm``.
      data_aug: stored as ``self.data_aug``.
      use_context: if True, map data is loaded from ``data/maps.pkl``.
      nbr_search_depth: stored as ``self.nbr_search_depth``.
      ds_seed: seed for numpy's *global* RNG, used to shuffle ``ind_random``.

    Side effects:
      Reads or creates the index cache
      ``multiple_futures_prediction/ngsim_data/NgsimIndex_<basename>.p``
      (relative to the current working directory) and reseeds ``np.random``.
    """
    # Parse the .mat file once and pull both variables from the same result.
    mat = scp.loadmat(mat_file)
    self.D = mat['traj']
    self.T = mat['tracks']
    self.t_h = t_h # length of track history
    self.t_f = t_f # length of predicted trajectory
    self.d_s = d_s # down sampling rate of all sequences
    self.enc_size = enc_size # size of encoder LSTM
    self.grid_size = (13,3) # size of context grid
    self.enc_fac = 2 if use_gru else 1
    self.self_norm = self_norm
    self.data_aug = data_aug
    # NOTE(review): presumably per-axis noise scale for data augmentation — confirm where used.
    self.noise = np.array([[0.5, 2.0]])
    self.dt = 0.1*self.d_s
    self.ft_to_m = 0.3048  # feet -> meters conversion factor
    self.use_context = use_context
    if self.use_context:
        # pickle can execute arbitrary code: only load trusted map files.
        with open('data/maps.pkl', 'rb') as maps_fh:
            self.maps = pickle.load(maps_fh)
    self.nbr_search_depth = nbr_search_depth

    # NOTE: the cache is keyed only on the .mat basename; two different files
    # with the same basename share a cache, and a stale cache is not detected.
    cache_file = 'multiple_futures_prediction/ngsim_data/NgsimIndex_%s.p'%os.path.basename(mat_file)
    # Build index of [dataset (0 based), veh_id_0b, frame(time)] -> row of self.D.
    if not os.path.exists(cache_file):
        print('building index...')
        # Columns 0 and 1 are converted from 1-based to 0-based; on duplicate
        # keys the later row wins.
        self.Index = {
            (int(row[0] - 1), int(row[1] - 1), int(row[2])): i
            for i, row in enumerate(self.D)
        }
        print('build index done')
        # Create the cache directory if needed; otherwise open() would raise.
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_file, 'wb') as cache_fh:
            pickle.dump(self.Index, cache_fh)
    else:
        # pickle can execute arbitrary code: only load trusted cache files.
        with open(cache_file, 'rb') as cache_fh:
            self.Index = pickle.load(cache_fh)

    self.ind_random = np.arange(len(self.D))
    self.seed = ds_seed
    # Seeds numpy's global RNG (kept global for reproducibility of other
    # np.random consumers that may rely on it).
    np.random.seed(self.seed)
    np.random.shuffle(self.ind_random)