def train()

in multiple_futures_prediction/train_ngsim.py
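
This excerpt omits the module preamble. Reconstructed from usage, the function needs at least the following imports; the repo-local names listed in the trailing comment are defined elsewhere in the package, and their exact import paths are not shown here:

import os
import time
import pickle
from typing import Any, List

import numpy as np
import torch
from torch.utils.data import DataLoader

# Also required from the repo's own modules (import paths not shown in this
# excerpt): AttrDict, NgsimDataset, mfpNet, get_mean, setup_logger, eval (the
# repo's own evaluation helper, which shadows the Python builtin), nll_loss,
# and nll_loss_multimodes.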


def train(params: AttrDict) -> Any:
  """Main training function: sets up the NGSIM data, builds the MFP network,
  and runs the two-phase training loop (pre-training, then interactive rollouts)."""
  # Seed torch and numpy for reproducibility.
  torch.manual_seed(params.seed)  # type: ignore
  np.random.seed(params.seed)

  # Data and sampling hyperparameters.
  batch_size = 1
  data_hz = 10                                 # NGSIM data is sampled at 10 Hz.
  ns_between_samples = (1.0 / data_hz) * 1e9   # nanoseconds between samples
  d_s = params.subsampling                     # temporal subsampling factor
  t_h = params.hist_len_orig_hz                # history length (frames at the original rate)
  t_f = params.fut_len_orig_hz                 # future length (frames at the original rate)
  NUM_WORKERS = 1

  DATA_PATH = 'multiple_futures_prediction/'

  # Load the train/val/test splits of the NGSIM dataset.
  train_set = NgsimDataset(DATA_PATH + 'ngsim_data/TrainSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru,
                           params.self_norm, params.data_aug, params.use_context, params.nbr_search_depth)
  val_set   = NgsimDataset(DATA_PATH + 'ngsim_data/ValSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru,
                           params.self_norm, params.data_aug, params.use_context, params.nbr_search_depth)
  test_set  = NgsimDataset(DATA_PATH + 'ngsim_data/TestSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru,
                           params.self_norm, params.data_aug, params.use_context, params.nbr_search_depth)

  # Only the training loader shuffles; each dataset supplies its own collate_fn.
  train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS,
                                 collate_fn=train_set.collate_fn, drop_last=True)  # type: ignore
  val_data_loader   = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS,
                                 collate_fn=val_set.collate_fn, drop_last=True)  # type: ignore
  test_data_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS,
                                 collate_fn=test_set.collate_fn, drop_last=False)  # type: ignore

  # Compute (or load a cached copy of) the mean over future trajectories.
  y_mean_path = DATA_PATH + 'ngsim_data/y_mean.pkl'
  if os.path.exists(y_mean_path):
    with open(y_mean_path, 'rb') as f:
      y_mean = pickle.load(f)
  else:
    y_mean = get_mean(train_data_loader)
    with open(y_mean_path, 'wb') as f:
      pickle.dump(y_mean, f)

  # Initialize the network.
  net = mfpNet(params)
  if params.use_cuda:
    net = net.cuda()  # type: ignore

  # Cache the mean on the network; keep a tensor copy for de-meaning the targets.
  net.y_mean = y_mean
  y_mean = torch.tensor(net.y_mean)

  if params.log:
    logger_file, logging_dir = setup_logger(DATA_PATH + 'checkpts/', 'NGSIM')
    
  train_loss: List = []
  val_loss: List = []
  
  MODE = 'Pre'  # For efficiency, we first pre-train without interactive rollouts.
  num_updates = 0
  optimizer = None
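
  # Two-phase schedule: while MODE == 'Pre', the model trains without
  # interactive rollouts. Once num_updates exceeds params.pre_train_num_updates,
  # the inner loop flips MODE to 'EndPre' and breaks out of the epoch; the next
  # epoch then switches MODE to 'Train' and enables interactive rollouts.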

  for epoch_num in range(20):
    if MODE == 'EndPre':
      MODE = 'Train'
    if MODE == 'Train':
      print('Training with interactive rollouts.')
      bStepByStep = True
    else:
      print('Pre-training without interactive rollouts.')
      bStepByStep = False

    # Accumulators for the current reporting window.
    avg_tr_loss = 0.
    avg_tr_time = 0.
    loss_counter = 0.0

    for i, data in enumerate(train_data_loader):
      if num_updates > params.pre_train_num_updates and MODE == 'Pre':
        MODE = 'EndPre'
        break

      # Step learning-rate decay: divide by 10 every params.updates_div_by_10
      # updates, floored at params.min_lr.
      lr_fac = np.power(0.1, num_updates // params.updates_div_by_10)
      lr = max(params.min_lr, params.lr_init * lr_fac)
      if optimizer is None or lr != optimizer.defaults['lr']:
        # Note: re-creating the optimizer on each LR change also resets Adam's
        # moment estimates.
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # type: ignore
      
      st_time = time.time()
      # The batch tuple carries the mask in two slots; only the second binding
      # was effective in the original unpacking, so the first slot is discarded.
      hist, nbrs, _, fut, mask, context, nbrs_info = data

      # Optionally subtract the precomputed mean from the future trajectories.
      if params.remove_y_mean:
        fut = fut - y_mean.unsqueeze(1)

      if params.use_cuda:
        hist = hist.cuda()
        nbrs = nbrs.cuda()
        mask = mask.cuda()
        fut = fut.cuda()
        if context is not None:
          context = context.cuda()

      # Forward pass; use the multimodal NLL when more than one mode is predicted.
      fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep)
      if params.modes == 1:
        l = nll_loss(fut_preds[0], fut, mask)
      else:
        l = nll_loss_multimodes(fut_preds, fut, mask, modes_pred)  # type: ignore

      # Backprop with gradient-norm clipping (max norm 10).
      optimizer.zero_grad()
      l.backward()
      torch.nn.utils.clip_grad_norm_(net.parameters(), 10)  # type: ignore
      optimizer.step()
      num_updates += 1

      batch_time = time.time() - st_time
      avg_tr_loss += l.item()
      avg_tr_time += batch_time

      effective_batch_sz = float(hist.shape[1])  # number of agents in this batch (dim 1 of hist)
      if num_updates % params.iter_per_err == params.iter_per_err-1:
        # Average over the reporting window (the accumulator is reset below).
        avg_loss = avg_tr_loss / params.iter_per_err
        print("Epoch no:", epoch_num, "update:", num_updates, "| Avg train loss:",
              format(avg_loss, '0.4f'), " learning_rate:%.5f" % lr)
        train_loss.append(avg_loss)

        if params.log:
          msg_str_ = ("Epoch no:", epoch_num, "update:", num_updates, "| Avg train loss:",
                      format(avg_loss, '0.4f'), " learning_rate:%.5f" % lr)
          msg_str = str([str(ss) for ss in msg_str_])
          logger_file.write(msg_str + '\n')
          logger_file.flush()

        avg_tr_loss = 0.
        # Periodically evaluate NLL on the validation set (note: this check is
        # nested inside the reporting block above, so iter_per_eval should be a
        # multiple of iter_per_err).
        if num_updates % params.iter_per_eval == params.iter_per_eval-1:
          print("Starting eval")
          val_nll_err = eval('nll', net, params, val_data_loader, bStepByStep,
                             use_forcing=params.use_forcing, y_mean=y_mean,
                             num_batches=500, dataset_name='val_dl nll')

          if params.log:
            logger_file.write('val nll: ' + str(val_nll_err) + '\n')
            logger_file.flush()

      # Periodically save model weights.
      if params.log and num_updates % params.iters_per_save == params.iters_per_save-1:
        msg_str = '\nSaving state, update iter:%d %s' % (num_updates, logging_dir)
        print(msg_str)
        logger_file.write(msg_str)
        logger_file.flush()
        torch.save(net.state_dict(), logging_dir + '/checkpoints/ngsim_%06d' % num_updates + '.pth')  # type: ignore
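
For reference, a minimal invocation sketch. It assumes AttrDict accepts a plain dict of hyperparameters; every value below is an illustrative placeholder (not the repo's published configuration), with only the key names taken from the usages in train() above:

if __name__ == '__main__':
  # Placeholder hyperparameters; replace with the official config before use.
  params = AttrDict({
      'seed': 1234, 'subsampling': 2, 'hist_len_orig_hz': 30, 'fut_len_orig_hz': 50,
      'encoder_size': 64, 'use_gru': True, 'self_norm': True, 'data_aug': False,
      'use_context': False, 'nbr_search_depth': 10, 'use_cuda': True, 'log': True,
      'remove_y_mean': True, 'modes': 2, 'use_forcing': 0,
      'lr_init': 1e-3, 'min_lr': 1e-5, 'updates_div_by_10': 100000,
      'pre_train_num_updates': 200000, 'iter_per_err': 100, 'iter_per_eval': 5000,
      'iters_per_save': 10000})
  train(params)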