# multiple_futures_prediction/train_ngsim.py
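# The imports this function relies on (torch, numpy as np, os, pickle, time,
# typing.Any/List, and torch.utils.data.DataLoader) live at the top of the
# full file, as do the package helpers it calls: NgsimDataset, mfpNet,
# get_mean, setup_logger, nll_loss, nll_loss_multimodes, eval, and AttrDict.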
def train(params: AttrDict) -> Any:
"""Main training function."""
torch.manual_seed( params.seed ) #type: ignore
np.random.seed( params.seed )
############################
batch_size = 1
data_hz = 10  # NGSIM is sampled at 10 Hz.
ns_between_samples = (1.0/data_hz)*1e9  # Nanoseconds between consecutive samples.
d_s = params.subsampling  # Temporal subsampling factor applied by the dataset.
t_h = params.hist_len_orig_hz  # History length, in frames at the original rate.
t_f = params.fut_len_orig_hz  # Future (prediction) length, in frames at the original rate.
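# Example (hypothetical values): with data_hz = 10, ns_between_samples is 1e8 ns
# (0.1 s). If d_s = 2, a history of t_h = 30 frames spans 30/10 = 3 s of real
# time and yields 15 subsampled frames for the encoder.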
NUM_WORKERS = 1
DATA_PATH = 'multiple_futures_prediction/'
# Loading the dataset.
train_set = NgsimDataset(DATA_PATH + 'ngsim_data/TrainSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
params.data_aug, params.use_context, params.nbr_search_depth)
val_set = NgsimDataset(DATA_PATH + 'ngsim_data/ValSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
params.data_aug, params.use_context, params.nbr_search_depth)
test_set = NgsimDataset(DATA_PATH + 'ngsim_data/TestSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
params.data_aug, params.use_context, params.nbr_search_depth)
train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, collate_fn=train_set.collate_fn, drop_last=True)  # type: ignore
val_data_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=val_set.collate_fn, drop_last=True)  # type: ignore
test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=test_set.collate_fn, drop_last=False)  # type: ignore
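# drop_last=True keeps train/val batch shapes constant across iterations; the
# test loader retains its final partial batch so every test sample is scored.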
# Compute or load existing mean over future trajectories.
if os.path.exists(DATA_PATH+'ngsim_data/y_mean.pkl'):
  with open(DATA_PATH+'ngsim_data/y_mean.pkl', 'rb') as f:
    y_mean = pickle.load(f)
else:
  y_mean = get_mean(train_data_loader)
  with open(DATA_PATH+'ngsim_data/y_mean.pkl', 'wb') as f:
    pickle.dump(y_mean, f)
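# y_mean is the mean future trajectory over the training set; when
# params.remove_y_mean is set, the network is trained on residuals
# (fut - y_mean) in the loop below.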
# Initialize network
net = mfpNet( params )
if params.use_cuda:
net = net.cuda() #type: ignore
net.y_mean = y_mean
y_mean = torch.tensor(net.y_mean)
if params.log:
logger_file, logging_dir = setup_logger(DATA_PATH+"checkpts/", 'NGSIM')
train_loss: List[float] = []
val_loss: List[float] = []
MODE='Pre' # For efficiency, we first pre-train w/o interactive rollouts.
num_updates = 0
optimizer = None
for epoch_num in range(20):  # Train for at most 20 epochs.
if MODE == 'EndPre':
MODE = 'Train'
print('Training with interactive rollouts.')
bStepByStep = True
else:
print('Pre-training without interactive rollouts.')
bStepByStep = False
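# Interactive rollouts decode the future step by step (bStepByStep), feeding
# predictions back in at each timestep, which is slower than the single
# teacher-forced pass used during pre-training; hence the warm-up above.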
# Average losses.
avg_tr_loss = 0.
avg_tr_time = 0.
loss_counter = 0.0
for i, data in enumerate(train_data_loader):
if num_updates > params.pre_train_num_updates and MODE == 'Pre':
MODE = 'EndPre'
break
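# MODE state machine: 'Pre' -> 'EndPre' once the pre-training update budget
# is spent (abort the current epoch) -> 'Train' at the top of the next epoch,
# which enables interactive rollouts.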
lr_fac = np.power(0.1, num_updates // params.updates_div_by_10 )
lr = max( params.min_lr, params.lr_init*lr_fac)
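# Step decay: lr = max(min_lr, lr_init * 0.1**(num_updates // updates_div_by_10)).
# E.g. with hypothetical settings lr_init=1e-3, updates_div_by_10=40000,
# min_lr=1e-5: updates 0-39999 run at 1e-3, 40000-79999 at 1e-4, and from
# 80000 on the 1e-5 floor applies.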
# Note: re-creating Adam whenever the lr steps down also resets its moment
# estimates; the schedule above changes lr only every updates_div_by_10 updates.
if optimizer is None or lr != optimizer.defaults['lr']:
  optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # type: ignore
st_time = time.time()
hist, nbrs, mask, fut, mask, context, nbrs_info = data  # NOTE: 'mask' is unpacked twice; only the second binding is used below.
if params.remove_y_mean:
  fut = fut - y_mean.unsqueeze(1)
if params.use_cuda:
  hist = hist.cuda()
  nbrs = nbrs.cuda()
  mask = mask.cuda()
  fut = fut.cuda()
  if context is not None:
    context = context.cuda()
# Forward pass.
fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep)
if params.modes == 1:
  loss = nll_loss(fut_preds[0], fut, mask)
else:
  loss = nll_loss_multimodes(fut_preds, fut, mask, modes_pred)  # type: ignore
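# Single-mode training reduces to a plain trajectory NLL; with multiple modes,
# nll_loss_multimodes also consumes modes_pred, the predicted distribution
# over latent futures.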
# Backprop with gradient-norm clipping (max norm 10) to guard against
# exploding gradients.
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), 10)  # type: ignore
optimizer.step()
num_updates += 1
batch_time = time.time() - st_time
avg_tr_loss += loss.item()
avg_tr_time += batch_time
effective_batch_sz = float(hist.shape[1])
if num_updates % params.iter_per_err == params.iter_per_err-1:
  # Report the loss averaged over the last params.iter_per_err updates.
  avg_loss = avg_tr_loss / params.iter_per_err
  msg_str = "Epoch no: %d update: %d | Avg train loss: %.4f learning_rate: %.5f" % (epoch_num, num_updates, avg_loss, lr)
  print(msg_str)
  train_loss.append(avg_loss)
  if params.log:
    logger_file.write(msg_str + '\n')
    logger_file.flush()
  avg_tr_loss = 0.
if num_updates % params.iter_per_eval == params.iter_per_eval-1:
print("Starting eval")
val_nll_err = eval( 'nll', net, params, val_data_loader, bStepByStep,
use_forcing=params.use_forcing, y_mean=y_mean,
num_batches=500, dataset_name='val_dl nll')
if params.log:
logger_file.write('val nll: ' + str(val_nll_err)+'\n')
logger_file.flush()
# Save weights.
if params.log and num_updates % params.iters_per_save == params.iters_per_save-1:
  msg_str = '\nSaving state, update iter:%d %s' % (num_updates, logging_dir)
  print(msg_str)
  logger_file.write(msg_str + '\n')
  logger_file.flush()
  torch.save(net.state_dict(), logging_dir + '/checkpoints/ngsim_%06d.pth' % num_updates)  # type: ignore