def train()

in train.py
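
Per-process training entry point. It computes the global rank, optionally initializes DDP, builds the data loaders, model, optimizer, scheduler, and class-prior adjustments, optionally resumes from latest.pth, then trains for args.epochs epochs: each step combines an in-distribution batch (two augmented views per image) with an auxiliary OOD batch and minimizes in_loss + Lambda * ood_loss + Lambda2 * aux_loss. After every epoch, rank 0 evaluates on the clean test set, logs overall/many/median/low-shot accuracy and macro F1, and saves the latest and best checkpoints (plus the ID feature pool when a feature-based OOD metric is used). A hypothetical launch sketch follows the listing.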


def train(gpu_id, ngpus_per_node, args): 

    save_dir = args.save_dir

    # get global rank (one process per GPU across all nodes):
    rank = args.node_id * ngpus_per_node + gpu_id

    print(f"Running on rank {rank}.")

    # initialize the DDP process group:
    if args.ddp:
        setup(rank, ngpus_per_node, args)

    # initialize device: with DDP each process owns one GPU (gpu_id), otherwise use the default CUDA device:
    device = gpu_id if args.ddp else 'cuda'

    # OOD samples are synthesized from ID features (VOS/NPOS) instead of drawn from a real auxiliary dataset:
    synthesis_ood_flag = args.ood_aux_dataset in ['VOS', 'NPOS']
    # feature-based OOD metrics (e.g. Mahalanobis) need a pool of ID features:
    require_feats_flag = 'maha' in args.ood_metric
    num_classes, train_loader, test_loader, ood_loader, train_sampler, img_num_per_cls_and_ood = build_dataset(args, ngpus_per_node)
    img_num_per_cls = img_num_per_cls_and_ood[:num_classes]  # keep only the per-class ID image counts

    model, optimizer, scheduler, num_outputs = build_model(args, num_classes, device, gpu_id)
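    # for feature-based OOD metrics, keep a per-class pool of penultimate-layer features
    # (up to max(img_num_per_cls) samples per class):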
    if require_feats_flag:
        model.id_feat_pool = IDFeatPool(num_classes, sample_num=max(img_num_per_cls), 
                                        feat_dim=model.penultimate_layer_dim, device=device)

    # class-prior logit adjustments built from the per-class image counts:
    adjustments = build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device)

    # train:
    if args.resume:
        ckpt = torch.load(osp.join(args.resume, 'latest.pth'), map_location='cpu')
        if is_parallel(model):
            ckpt['model'] = {'module.' + k: v for k, v in ckpt['model'].items()}
        model.load_state_dict(ckpt['model'], strict=False)
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
        except Exception:
            # optimizer/scheduler state may be missing or incompatible; fall back to the freshly built ones
            pass
        start_epoch = ckpt['epoch']+1 
        best_overall_acc = ckpt['best_overall_acc']
        training_losses = ckpt['training_losses']
        test_clean_losses = ckpt['test_clean_losses']
        f1s = ckpt['f1s']
        overall_accs = ckpt['overall_accs']
        many_accs = ckpt['many_accs']
        median_accs = ckpt['median_accs']
        low_accs = ckpt['low_accs']
    else:
        training_losses, test_clean_losses = [], []
        f1s, overall_accs, many_accs, median_accs, low_accs = [], [], [], [], []
        best_overall_acc = 0
        start_epoch = 0

    # logging: per-batch training log, per-epoch validation log, and a snapshot of the model definition:
    fp = open(osp.join(save_dir, 'train_log.txt'), 'a+')
    fp_val = open(osp.join(save_dir, 'val_log.txt'), 'a+')
    shutil.copyfile('models/base.py', f'{save_dir}/base.py')
    for epoch in range(start_epoch, args.epochs):
        # reset sampler when using ddp:
        if args.ddp:
            train_sampler.set_epoch(epoch)
        start_time = time.time()

        model.train()
        training_loss_meter = AverageMeter()
        current_lr = scheduler.get_last_lr()
        pbar = zip(train_loader, ood_loader)
        stop_flag = False
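        # each step draws a batch of ID images (two augmented views per image) together with a batch of auxiliary OOD images: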
        for batch_idx, ((in_data, labels), (ood_data, _)) in enumerate(pbar):
            in_data = torch.cat([in_data[0], in_data[1]], dim=0) # shape=(2*N,C,H,W). Two views of each image.
            in_data, labels = in_data.to(device), labels.to(device)
            ood_data = ood_data.to(device)

            # forward:
            if not synthesis_ood_flag and not require_feats_flag:
                # standard case: forward ID and auxiliary OOD data together
                all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout,C,H,W)
                in_loss, ood_loss, aux_loss = model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args)
            elif synthesis_ood_flag:
                # VOS/NPOS: outliers are synthesized inside the model; feed the ID features back to the OOD loader
                in_loss, ood_loss, aux_loss, id_feats = \
                    model(in_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, ood_data=ood_data, return_features=True)
                ood_loader.update(id_feats.detach().clone(), labels)
            elif require_feats_flag:
                # feature-based OOD metric: also return features so the ID feature pool can be updated below
                all_data = torch.cat([in_data, ood_data], dim=0) # shape=(2*Nin+Nout,C,H,W)
                num_ood = len(ood_data)
                in_loss, ood_loss, aux_loss, id_feats = \
                    model(all_data, mode='calc_loss', labels=labels, adjustments=adjustments, args=args, return_features=True)

            loss: torch.Tensor = in_loss + args.Lambda * ood_loss + args.Lambda2 * aux_loss
            if torch.isnan(loss):
                print('Warning: Loss is NaN. Training stopped.')
                stop_flag = True
                break
            if require_feats_flag:
                # store only the ID features (the last num_ood rows belong to the auxiliary OOD batch);
                # labels are duplicated because in_data holds two views per image
                model.id_feat_pool.update(id_feats[:-num_ood].detach().clone(), torch.cat((labels, labels)))

            # backward:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # append:
            training_loss_meter.append(loss.item())
            if rank == 0 and batch_idx % 100 == 0:
                train_str = '%s epoch %d batch %d (train): loss %.4f (%.4f, %.4f, %.4f) | lr %s' % (
                    datetime.now().strftime("%D %H:%M:%S"),
                    epoch, batch_idx, loss.item(), in_loss.item(), ood_loss.item(), aux_loss.item(), current_lr) 
                print(train_str)
                fp.write(train_str + '\n')
                fp.flush()

        if stop_flag:
            print('Use the model at epoch', epoch - 1)
            break

        # lr update:
        scheduler.step()

        # evaluation, logging and checkpointing run on rank 0 only:
        if rank == 0:
            # eval on clean set:
            model.eval()
            test_acc_meter, test_loss_meter = AverageMeter(), AverageMeter()
            preds_list, labels_list = [], []
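            # run the clean test set through the model; parse_logits extracts the ID logits for the configured OOD metric: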
            with torch.no_grad():
                for data, labels in test_loader:
                    data, labels = data.to(device), labels.to(device)
                    logits, features = model(data, return_features=True)
                    in_logits = de_parallel(model).parse_logits(logits, features, args.ood_metric, logits.shape[0])[0]
                    pred = in_logits.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    loss = F.cross_entropy(in_logits, labels)
                    test_acc_meter.append((in_logits.argmax(1) == labels).float().mean().item())
                    test_loss_meter.append(loss.item())
                    preds_list.append(pred)
                    labels_list.append(labels)

            preds = torch.cat(preds_list, dim=0).detach().cpu().numpy().squeeze()
            labels = torch.cat(labels_list, dim=0).detach().cpu().numpy()

            overall_acc = (preds == labels).sum().item() / len(labels)
            f1 = f1_score(labels, preds, average='macro')

            # accuracy on many-, median-, and low-shot class groups (split by training-class frequency):
            many_acc, median_acc, low_acc, _ = shot_acc(preds, labels, img_num_per_cls, acc_per_cls=True)

            test_clean_losses.append(test_loss_meter.avg)
            f1s.append(f1)
            overall_accs.append(overall_acc)
            many_accs.append(many_acc)
            median_accs.append(median_acc)
            low_accs.append(low_acc)

            val_str = '%s epoch %d (test): ACC %.4f (%.4f, %.4f, %.4f) | F1 %.4f | time %s' % \
                        (datetime.now().strftime("%D %H:%M:%S"), epoch, overall_acc, many_acc, median_acc, low_acc, f1, time.time()-start_time) 
            print(val_str)
            fp_val.write(val_str + '\n')
            fp_val.flush()

            # save curves:
            training_losses.append(training_loss_meter.avg)
            save_curve(args, save_dir, training_losses, test_clean_losses, 
                       overall_accs, many_accs, median_accs, low_accs, f1s)

            # save best model:
            model_state_dict = de_parallel(model).state_dict()
            if overall_accs[-1] > best_overall_acc and epoch >= args.epochs * 0.75:
                best_overall_acc = overall_accs[-1]
                torch.save(model_state_dict, osp.join(save_dir, 'best_clean_acc.pth'))
            
            # save feature pool
            if synthesis_ood_flag:
                ood_loader.save(osp.join(save_dir, 'id_feats.pth'))
            elif require_feats_flag:
                model.id_feat_pool.save(osp.join(save_dir, 'id_feats.pth'))  # same file format as the synthesis-OOD branch

            # save pth:
            torch.save({
                'model': model_state_dict,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch, 
                'best_overall_acc': best_overall_acc,
                'training_losses': training_losses, 
                'test_clean_losses': test_clean_losses, 
                'f1s': f1s, 
                'overall_accs': overall_accs, 
                'many_accs': many_accs, 
                'median_accs': median_accs, 
                'low_accs': low_accs, 
                }, 
                osp.join(save_dir, 'latest.pth'))
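            # additionally snapshot the model every args.save_epochs epochs: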
            if args.save_epochs > 0 and epoch % args.save_epochs == 0:
                torch.save({
                    'model': model_state_dict,
                    'optimizer': optimizer.state_dict(),
                    }, osp.join(save_dir, f'epoch{epoch}.pth'))
                if synthesis_ood_flag:
                    ood_loader.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth'))
                elif require_feats_flag:
                    model.id_feat_pool.save(osp.join(save_dir, f'id_feats_epoch{epoch}.pth'))  # same file format as the synthesis-OOD branch
    
    # Clean up ddp:
    if args.ddp:
        cleanup()
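
A minimal launch sketch (not part of train.py), assuming train() is spawned once per local GPU when args.ddp is set and called directly otherwise; main() and the argparse wiring are hypothetical:

import torch
import torch.multiprocessing as mp

def main(args):
    # hypothetical launcher: one process per local GPU when DDP is enabled
    ngpus_per_node = torch.cuda.device_count()
    if args.ddp:
        # mp.spawn calls train(gpu_id, ngpus_per_node, args) with gpu_id = 0 .. ngpus_per_node - 1
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        train(0, ngpus_per_node, args)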