def _init_state()

in imnet_finetune/train.py


    def _init_state(self) -> None:
        """
        Initialize the state and load it from an existing checkpoint, if any.
        """
        torch.manual_seed(0)
        np.random.seed(0)
        print("Create data loaders", flush=True)
        
        input_size = self._train_cfg.input_size

        print(f"Input size: {input_size}")
        print(f"Model: {self._train_cfg.architecture}")
        backbone_architecture = None
        if self._train_cfg.architecture == 'PNASNet':
            backbone_architecture = 'pnasnet5large'

        transformation = get_transforms(input_size=self._train_cfg.input_size, test_size=self._train_cfg.input_size,
                                        kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture)
        transform_test = transformation['val']

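        # Note: the training set deliberately reuses the test-time ('val') transform,
        # since FixRes fine-tunes the network at the test resolution.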
        train_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/train', transform=transform_test)
        
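        # Shard the training set across tasks so each rank sees a distinct subset every epoch.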
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_set, num_replicas=self._train_cfg.num_tasks, rank=self._train_cfg.global_rank
        )
        self._train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self._train_cfg.batch_per_gpu,
            num_workers=self._train_cfg.workers - 1,
            sampler=train_sampler,
        )
        test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val', transform=transform_test)
        self._test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=self._train_cfg.batch_per_gpu,
            shuffle=False,
            num_workers=self._train_cfg.workers - 1,
        )

        print(f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}", flush=True)

        print("Create distributed model", flush=True)
        
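        # Instantiate the requested backbone; only the branch matching the config runs.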
        if self._train_cfg.architecture == 'PNASNet':
            model = pnasnet5large(pretrained='imagenet')
            
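        # The ResNet-50 checkpoint keys carry a 'module.' prefix (saved from a
        # DataParallel/DDP model); strip it while copying weights into model_dict.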
        if self._train_cfg.architecture == 'ResNet50':
            model = models.resnet50(pretrained=False)
            pretrained_dict = torch.load(self._train_cfg.resnet_weight_path, map_location='cpu')['model']
            model_dict = model.state_dict()
            count = 0
            count2 = 0
            for k in model_dict.keys():
                count += 1
                if 'module.' + k in pretrained_dict:
                    count2 += 1
                    model_dict[k] = pretrained_dict['module.' + k]
            model.load_state_dict(model_dict)
            print(f"load {count2 * 100 / count} %")
            assert count2 == count, "model loading error"
            
        if self._train_cfg.architecture == 'IGAM_Resnext101_32x48d':
            model = resnext101_32x48d_wsl(progress=True)

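        # Freeze the backbone for fine-tuning: PNASNet keeps its last three cells and the
        # classifier trainable; other architectures (except EfficientNet) train only 'fc'.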
        if self._train_cfg.architecture == 'PNASNet':
            for name, child in model.named_children():
                if 'last_linear' not in name and 'cell_11' not in name and 'cell_10' not in name and 'cell_9' not in name:
                    for _, params in child.named_parameters():
                        params.requires_grad = False
        elif self._train_cfg.architecture != 'EfficientNet':
            for name, child in model.named_children():
                if 'fc' not in name:
                    for _, params in child.named_parameters():
                        params.requires_grad = False
    
        if self._train_cfg.architecture == 'EfficientNet':
            assert has_timm
            # See https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py for model names
            model = create_model(self._train_cfg.EfficientNet_models, pretrained=False, num_classes=1000)
            for name, child in model.named_children():
                if 'classifier' not in name:
                    for _, params in child.named_parameters():
                        params.requires_grad = False

            pretrained_dict = load_state_dict_from_url(default_cfgs[self._train_cfg.EfficientNet_models]['url'], map_location='cpu')
            model_dict = model.state_dict()
            for k in model_dict.keys():
                if k in pretrained_dict:
                    model_dict[k] = pretrained_dict[k]
            model.load_state_dict(model_dict)
            torch.cuda.empty_cache()
            # requires_grad is a Tensor attribute: set it on the parameters, not on the
            # modules themselves (the original assignment on the modules had no effect).
            for p in model.classifier.parameters():
                p.requires_grad = True
            for p in model.conv_head.parameters():
                p.requires_grad = True
            
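        # Wrap the model for synchronous distributed training; gradients are averaged across ranks.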
        model.cuda(self._train_cfg.local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank
        )
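        # Linear learning-rate scaling: the base lr is multiplied by the global batch size
        # (batch_per_gpu * num_tasks) relative to a reference batch of 512, times a factor of 8.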
        linear_scaled_lr = 8.0 * self._train_cfg.lr * self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks / 512.0
        optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9, weight_decay=1e-4)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30)
        self._state = TrainerState(
            epoch=0, accuracy=0.0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
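        # Resume transparently when a checkpoint from a previous (e.g. preempted) run exists.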
        checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth")
        if os.path.isfile(checkpoint_fn):
            print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
            self._state = TrainerState.load(checkpoint_fn, default=self._state)
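
For intuition, here is a minimal sketch of what the linear scaling rule above produces; the lr, batch_per_gpu, and num_tasks values are assumptions chosen for illustration, not the repo's defaults:

    # Hypothetical numbers, purely to illustrate linear_scaled_lr above
    lr = 0.01           # base learning rate from the config (assumed)
    batch_per_gpu = 32  # per-GPU batch size (assumed)
    num_tasks = 8       # number of distributed workers (assumed)
    linear_scaled_lr = 8.0 * lr * batch_per_gpu * num_tasks / 512.0
    print(linear_scaled_lr)  # -> 0.04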