experiments/overlap/train_net.py [33:82]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        chpk_pre = 'model_epoch_'
        if name is not None:
            chpk_pre = name + "_" + chpk_pre
        chpk_post = '.pyth'
        if os.path.exists(checkpoint_folder):
            checkpoints = [c for c in os.listdir(checkpoint_folder) if chpk_post in c and chpk_pre == "_".join(c.split("_")[:-1]) +"_"]
        else:
            checkpoints = []
        if weights:
            checkpoint = torch.load(weights, map_location='cpu')
            log.info("Pretrained weights provided.  Loading model from {} and skipping training.".format(weights))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])

            return model
        elif checkpoints:
            last_checkpoint_name = os.path.join(checkpoint_folder, sorted(checkpoints)[-1])
            checkpoint = torch.load(last_checkpoint_name, map_location='cpu')
            log.info("Loading model from {}".format(last_checkpoint_name))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = 1

        if train_dataset is None:
            return model

        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\
                if num_gpus > 1 else None
        loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True if sampler is None else False,
                sampler=sampler,
                num_workers=loader_params.num_workers,
                pin_memory=loader_params.pin_memory,
                drop_last=True
                )

        for i in range(start_epoch, max_epoch+1):
            log.info("Starting epoch {}/{}".format(i, max_epoch))
            time_start = time.time()
            if sampler:
                sampler.set_epoch(i)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



experiments/overlap/train_net_jsd.py [37:86]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        chpk_pre = 'model_epoch_'
        if name is not None:
            chpk_pre = name + "_" + chpk_pre
        chpk_post = '.pyth'
        if os.path.exists(checkpoint_folder):
            checkpoints = [c for c in os.listdir(checkpoint_folder) if chpk_post in c and chpk_pre == "_".join(c.split("_")[:-1]) +"_"]
        else:
            checkpoints = []
        if weights:
            checkpoint = torch.load(weights, map_location='cpu')
            log.info("Pretrained weights provided.  Loading model from {} and skipping training.".format(weights))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])

            return model
        elif checkpoints:
            last_checkpoint_name = os.path.join(checkpoint_folder, sorted(checkpoints)[-1])
            checkpoint = torch.load(last_checkpoint_name, map_location='cpu')
            log.info("Loading model from {}".format(last_checkpoint_name))
            if num_gpus > 1:
                model.module.load_state_dict(checkpoint['model_state'])
            else:
                model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = 1

        if train_dataset is None:
            return model

        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\
                if num_gpus > 1 else None
        loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True if sampler is None else False,
                sampler=sampler,
                num_workers=loader_params.num_workers,
                pin_memory=loader_params.pin_memory,
                drop_last=True
                )

        for i in range(start_epoch, max_epoch+1):
            log.info("Starting epoch {}/{}".format(i, max_epoch))
            time_start = time.time()
            if sampler:
                sampler.set_epoch(i)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



