in imnet_finetune/train.py [0:0]
def _init_state(self) -> None:
"""
Initialize the state and load it from an existing checkpoint if any
"""
# Seed the RNGs so that every task starts from the same state.
torch.manual_seed(0)
np.random.seed(0)
print("Create data loaders", flush=True)
print(f"Input size : {self._train_cfg.input_size}")
print(f"Model : {self._train_cfg.architecture}")
backbone_architecture = None
if self._train_cfg.architecture == 'PNASNet':
    backbone_architecture = 'pnasnet5large'
transformation = get_transforms(input_size=self._train_cfg.input_size, test_size=self._train_cfg.input_size,
                                kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture)
transform_test = transformation['val']
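# Note: the training set is built with the validation-time transform, so the model is fine-tuned
# at the resolution it will see at test time.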
train_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/train', transform=transform_test)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_set, num_replicas=self._train_cfg.num_tasks, rank=self._train_cfg.global_rank
)
self._train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=self._train_cfg.batch_per_gpu,
    num_workers=self._train_cfg.workers - 1,
    sampler=train_sampler,
)
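# The validation loader is not sharded with a DistributedSampler, so every task evaluates the full set.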
test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val', transform=transform_test)
self._test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=self._train_cfg.workers - 1,
)
print(f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}", flush=True)
print("Create distributed model", flush=True)
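# Build the requested architecture. Each branch loads pretrained weights and freezes most of the
# backbone, so only the last layers (the classifier head, plus the final cells for PNASNet) are fine-tuned.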
if self._train_cfg.architecture == 'PNASNet':
    model = pnasnet5large(pretrained='imagenet')
if self._train_cfg.architecture == 'ResNet50':
    model = models.resnet50(pretrained=False)
    # The checkpoint was saved from a DistributedDataParallel model, so its keys carry a 'module.' prefix.
    pretrained_dict = torch.load(self._train_cfg.resnet_weight_path, map_location='cpu')['model']
    model_dict = model.state_dict()
    count = 0
    count2 = 0
    for k in model_dict.keys():
        count += 1
        if ('module.' + k) in pretrained_dict:
            count2 += 1
            model_dict[k] = pretrained_dict['module.' + k]
    model.load_state_dict(model_dict)
    print("load " + str(count2 * 100 / count) + " %")
    # Every parameter of the model must have been found in the checkpoint.
    assert int(count2 * 100 / count) == 100, "model loading error"
if self._train_cfg.architecture == 'IGAM_Resnext101_32x48d':
    model = resnext101_32x48d_wsl(progress=True)
# Freeze the backbone: for PNASNet only the last three cells and the final linear layer stay
# trainable; for the other architectures (EfficientNet is handled below) only the 'fc' head does.
if self._train_cfg.architecture == 'PNASNet':
    for name, child in model.named_children():
        if 'last_linear' not in name and 'cell_11' not in name and 'cell_10' not in name and 'cell_9' not in name:
            for name2, params in child.named_parameters():
                params.requires_grad = False
elif self._train_cfg.architecture != 'EfficientNet':
    for name, child in model.named_children():
        if 'fc' not in name:
            for name2, params in child.named_parameters():
                params.requires_grad = False
if self._train_cfg.architecture == 'EfficientNet':
    assert has_timm, "timm is required for the EfficientNet models"
    # Model names are listed in
    # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py
    model = create_model(self._train_cfg.EfficientNet_models, pretrained=False, num_classes=1000)
    # Freeze everything except the classifier head.
    for name, child in model.named_children():
        if 'classifier' not in name:
            for name2, params in child.named_parameters():
                params.requires_grad = False
    pretrained_dict = load_state_dict_from_url(
        default_cfgs[self._train_cfg.EfficientNet_models]['url'], map_location='cpu'
    )
    model_dict = model.state_dict()
    for k in model_dict.keys():
        if k in pretrained_dict:
            model_dict[k] = pretrained_dict[k]
    model.load_state_dict(model_dict)
    torch.cuda.empty_cache()
    # requires_grad must be set on the parameters, not on the modules, for it to take effect.
    for params in model.classifier.parameters():
        params.requires_grad = True
    for params in model.conv_head.parameters():
        params.requires_grad = True
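# Move the model to this task's GPU and wrap it for synchronous distributed training.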
model.cuda(self._train_cfg.local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank
)
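# Linear learning-rate scaling: the base lr is multiplied by the global batch size
# (batch_per_gpu * num_tasks) and divided by a reference batch size of 512, with a constant factor of 8.
# StepLR then divides the learning rate by 10 (its default gamma) every 30 epochs.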
linear_scaled_lr = 8.0 * self._train_cfg.lr * self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks / 512.0
optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9, weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30)
self._state = TrainerState(
    epoch=0, accuracy=0.0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
)
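# Resume from a previously saved checkpoint for this job, if one exists.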
checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth")
if os.path.isfile(checkpoint_fn):
    print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
    self._state = TrainerState.load(checkpoint_fn, default=self._state)