def train_net()

in experiments/overlap/train_net_jsd.py [0:0]
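For context, `train_net` depends on a handful of module-level names that are not shown in this excerpt. Below is a minimal sketch of the assumed surroundings; the `log` and `eta_str` stand-ins are illustrative, and `train_epoch` is the per-epoch training step defined elsewhere in the file.

import datetime
import logging
import os
import time

import torch

# Assumed stand-ins -- the real module defines its own logger, the
# train_epoch() step (which applies the JSD loss), and an eta_str() helper.
log = logging.getLogger(__name__)

def eta_str(eta_td):
    # Assumed helper: format a datetime.timedelta for log output.
    return str(eta_td)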


def train_net(model, optimizer, train_dataset,
              batch_size,
              max_epoch,
              loader_params,
              lr_policy,
              checkpoint_folder='checkpoints',
              name=None,
              save_period=1,
              weights=None,
              num_gpus=1,
              is_leader=True,
              jsd_num=3,
              jsd_alpha=12.0):

        # Checkpoints are named '<name>_model_epoch_NNNN.pyth'; collect any
        # that already exist so training can resume from the latest one.
        chpk_pre = 'model_epoch_'
        if name is not None:
            chpk_pre = name + "_" + chpk_pre
        chpk_post = '.pyth'
        if os.path.exists(checkpoint_folder):
            checkpoints = [c for c in os.listdir(checkpoint_folder)
                           if chpk_post in c
                           and chpk_pre == "_".join(c.split("_")[:-1]) + "_"]
        else:
            checkpoints = []
        if weights:
            # Explicit pretrained weights take priority: load them and skip
            # training entirely.
            checkpoint = torch.load(weights, map_location='cpu')
            log.info("Pretrained weights provided. Loading model from {} and skipping training.".format(weights))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])
            return model
        elif checkpoints:
            # Resume from the most recent checkpoint. The zero-padded epoch
            # number ({:04d}) makes lexicographic sort match numeric order.
            last_checkpoint_name = os.path.join(checkpoint_folder, sorted(checkpoints)[-1])
            checkpoint = torch.load(last_checkpoint_name, map_location='cpu')
            log.info("Loading model from {}".format(last_checkpoint_name))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = 1

        # No dataset means there is nothing to train on (e.g. a load-only
        # run); return the model in whatever state was loaded above.
        if train_dataset is None:
            return model

        # Under multi-GPU training each process sees a distinct shard via
        # DistributedSampler; the sampler then owns shuffling, so the loader
        # itself only shuffles in the single-process case.
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) \
                if num_gpus > 1 else None
        loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=(sampler is None),
                sampler=sampler,
                num_workers=loader_params.num_workers,
                pin_memory=loader_params.pin_memory,
                drop_last=True
                )

        for i in range(start_epoch, max_epoch + 1):
            log.info("Starting epoch {}/{}".format(i, max_epoch))
            time_start = time.time()
            if sampler:
                # Reseed the distributed sampler so each epoch gets a
                # different shuffle order across processes.
                sampler.set_epoch(i)
            train_epoch(model, optimizer, loader, lr_policy, i, num_gpus,
                        jsd_num=jsd_num, jsd_alpha=jsd_alpha)
            time_stop = time.time()
            seconds_taken = time_stop - time_start
            # ETA extrapolates the duration of the epoch just finished over
            # the epochs that remain.
            eta_td = datetime.timedelta(seconds=int(seconds_taken * (max_epoch - i)))
            log.info("Seconds taken: {:.2f}, Time remaining: {}".format(seconds_taken, eta_str(eta_td)))

            # Only the leader process writes checkpoints; unwrap the DDP
            # wrapper so the saved state_dict keys carry no 'module.' prefix.
            if (i % save_period == 0 or i == max_epoch) and is_leader:
                if num_gpus > 1:
                    m = model.module
                else:
                    m = model

                checkpoint = {
                        'epoch': i,
                        'model_state': m.state_dict(),
                        'optimizer_state': optimizer.state_dict()
                        }
                checkpoint_file = "{:s}{:04d}{:s}".format(chpk_pre, i, chpk_post)
                # makedirs with exist_ok handles nested paths and avoids the
                # check-then-create race that os.mkdir was exposed to.
                os.makedirs(checkpoint_folder, exist_ok=True)
                checkpoint_file = os.path.join(checkpoint_folder, checkpoint_file)
                log.info("Saving model to {}".format(checkpoint_file))
                torch.save(checkpoint, checkpoint_file)
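
The `jsd_num`/`jsd_alpha` arguments are only forwarded to `train_epoch`, whose body is not shown here. Given the file name and the AugMix-style defaults (3 views, weight 12.0), the consistency term presumably looks something like the sketch below; this is an assumption about `train_epoch`, not code from the repository.

import torch
import torch.nn.functional as F

def jsd_consistency_loss(logits_list, alpha=12.0):
    # Jensen-Shannon divergence between the softmax outputs of jsd_num
    # augmented views of the same batch, scaled by alpha (AugMix-style).
    probs = [F.softmax(logits, dim=1) for logits in logits_list]
    mixture = torch.stack(probs).mean(dim=0).clamp(1e-7, 1.0).log()
    jsd = sum(F.kl_div(mixture, p, reduction='batchmean') for p in probs)
    return alpha * jsd / len(probs)

And a minimal call-site sketch; the model, dataset, `lr_policy`, and loader settings below are placeholders, not the experiment's actual configuration.

from types import SimpleNamespace

loader_params = SimpleNamespace(num_workers=4, pin_memory=True)  # assumed shape

model = build_model()  # hypothetical constructor
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model = train_net(model, optimizer, train_dataset,  # train_dataset built elsewhere
                  batch_size=128,
                  max_epoch=100,
                  loader_params=loader_params,
                  lr_policy=lr_policy,  # hypothetical schedule consumed by train_epoch
                  name='jsd_run',
                  save_period=5,
                  jsd_num=3,
                  jsd_alpha=12.0)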