in imnet_extract/train.py [0:0]
def _init_state(self) -> None:
    """
    Initialize the state and load it from an existing checkpoint, if any.
    """
    # Fix the random seeds so that extraction runs are reproducible.
    torch.manual_seed(0)
    np.random.seed(0)

    print("Create data loaders", flush=True)
    print(f"Input size : {self._train_cfg.input_size}")
    print(f"Model : {self._train_cfg.architecture}")
    backbone_architecture = None
    if self._train_cfg.architecture == 'PNASNet':
        backbone_architecture = 'pnasnet5large'
    transformation = get_transforms(
        input_size=self._train_cfg.input_size, test_size=self._train_cfg.input_size,
        kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture,
    )
    transform_test = transformation['val']
    test_set = datasets.ImageFolder(self._train_cfg.dataset_path, transform=transform_test)
    self._test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False,
        num_workers=self._train_cfg.workers - 1,
    )
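    # Note (reading of intent, not stated in the original code): shuffle=False keeps samples in
    # ImageFolder's deterministic order, and workers - 1 presumably leaves one CPU worker free
    # for the main process.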
print("Create distributed model", flush=True)
if self._train_cfg.architecture=='PNASNet' :
model= pnasnet5large(pretrained='imagenet')
if self._train_cfg.architecture=='ResNet50' :
model=resnet50(pretrained=False)
if self._train_cfg.architecture=='IGAM_Resnext101_32x48d' :
model=resnext101_32x48d_wsl(progress=True)

    # The checkpoint at weight_path was saved from a DistributedDataParallel model, so its
    # state-dict keys carry a 'module.' prefix; copy them onto the bare model and verify that
    # every key was matched.
    pretrained_dict = torch.load(self._train_cfg.weight_path, map_location='cpu')['model']
    model_dict = model.state_dict()
    count = 0
    count2 = 0
    for k in model_dict.keys():
        count = count + 1.0
        if ('module.' + k) in pretrained_dict.keys():
            count2 = count2 + 1.0
            model_dict[k] = pretrained_dict.get('module.' + k)
    model.load_state_dict(model_dict)
    print(f"load {count2 * 100 / count} %")
    assert int(count2 * 100 / count) == 100, "model loading error"
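
    # For reference, the remapping above is equivalent to the one-liner below (illustrative
    # sketch only; it is not part of the original logic and skips the coverage check):
    #   model_dict = {k: pretrained_dict.get('module.' + k, v) for k, v in model_dict.items()}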

    # Freeze every parameter: the model is only used for feature extraction.
    for param in model.parameters():
        param.requires_grad = False
    print('model_load')
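
    # Optional sanity check (a sketch, not in the original file): confirm the freeze took
    # effect before wrapping the model, e.g.
    #   assert all(not p.requires_grad for p in model.parameters())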

    if torch.cuda.is_available():
        model.cuda(self._train_cfg.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank
    )
    # Only the model is tracked in the trainer state; no optimizer or scheduler is created here.
    self._state = TrainerState(model=model)

    # Resume from an existing checkpoint of the same job, if one was saved.
    checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth")
    if os.path.isfile(checkpoint_fn):
        print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
        self._state = TrainerState.load(checkpoint_fn, default=self._state)
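
# Illustrative usage sketch (assumptions, not part of the original file): _init_state is a
# method on a trainer object, so it is reached roughly as below once the distributed process
# group is up. The Trainer/TrainerConfig names are hypothetical; the fields mirror the
# attributes read above (architecture, input_size, dataset_path, weight_path, batch_per_gpu,
# workers, local_rank, save_folder, job_id), but the exact constructor signatures are assumed.
#
#   import torch.distributed as dist
#
#   dist.init_process_group(backend="nccl", init_method="env://")
#   cfg = TrainerConfig(
#       architecture="ResNet50",
#       input_size=384,
#       dataset_path="/path/to/imagenet",
#       weight_path="/path/to/finetuned_checkpoint.pth",
#       batch_per_gpu=32,
#       workers=10,
#       local_rank=0,  # one GPU per process; usually provided by the launcher
#       save_folder="./runs",
#       job_id=0,
#   )
#   trainer = Trainer(cfg)
#   trainer._init_state()  # builds the frozen DDP model and test loader, resumes if possible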