def _init_state()

in imnet_extract/train.py [0:0]


    def _init_state(self) -> None:
        """
        Initialize the state and load it from an existing checkpoint if any.

        Steps:
          1. Seed the torch and numpy RNGs with 0.
          2. Build ``self._test_loader`` over ``dataset_path`` (an
             ``ImageFolder``) using the 'val' transform from ``get_transforms``.
          3. Instantiate the backbone named by ``architecture`` and copy in the
             weights stored under ``['model']`` in ``weight_path``. Checkpoint
             keys are looked up with a ``module.`` prefix.
          4. Set ``requires_grad = False`` on the parameters of every direct
             child module of the model.
          5. When CUDA is available, move the model to ``local_rank`` and wrap
             it in ``DistributedDataParallel``.
          6. Store the model in ``self._state``. If
             ``<save_folder>/<job_id>/checkpoint.pth`` exists, restore the
             state from it.

        Raises:
            ValueError: if ``architecture`` is not one of the supported names.
            RuntimeError: if the model has no state entries, or if the
                checkpoint at ``weight_path`` does not provide every entry of
                the model's state dict.
        """
        torch.manual_seed(0)
        np.random.seed(0)

        cfg = self._train_cfg

        print("Create data loaders", flush=True)
        print("Input size : " + str(cfg.input_size))
        print("Model : " + str(cfg.architecture))

        # Lazy factories: only the requested backbone is constructed.
        model_factories = {
            'PNASNet': lambda: pnasnet5large(pretrained='imagenet'),
            'ResNet50': lambda: resnet50(pretrained=False),
            'IGAM_Resnext101_32x48d': lambda: resnext101_32x48d_wsl(progress=True),
        }
        if cfg.architecture not in model_factories:
            # Fail fast, before scanning the dataset, instead of hitting an
            # unbound `model` variable further down.
            raise ValueError(
                f"Unsupported architecture {cfg.architecture!r}; "
                f"expected one of {sorted(model_factories)}"
            )

        # get_transforms is only told about the backbone for PNASNet.
        backbone_architecture = 'pnasnet5large' if cfg.architecture == 'PNASNet' else None

        transformation = get_transforms(
            input_size=cfg.input_size,
            test_size=cfg.input_size,
            kind='full',
            crop=True,
            need=('train', 'val'),
            backbone=backbone_architecture,
        )
        transform_test = transformation['val']

        test_set = datasets.ImageFolder(cfg.dataset_path, transform=transform_test)

        # One fewer loader worker than configured, clamped at 0 because
        # DataLoader rejects a negative num_workers.
        self._test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=cfg.batch_per_gpu,
            shuffle=False,
            num_workers=max(0, cfg.workers - 1),
        )

        print("Create distributed model", flush=True)
        model = model_factories[cfg.architecture]()

        # Checkpoint keys are looked up with a 'module.' prefix (presumably the
        # weights were saved from a DataParallel/DDP-wrapped model — confirm).
        pretrained_dict = torch.load(cfg.weight_path, map_location='cpu')['model']
        model_dict = model.state_dict()
        total = len(model_dict)
        loaded = 0
        for key in model_dict:
            prefixed_key = 'module.' + key
            if prefixed_key in pretrained_dict:
                model_dict[key] = pretrained_dict[prefixed_key]
                loaded += 1
        model.load_state_dict(model_dict)

        if total == 0:
            raise RuntimeError("model loading error: model has no state entries to load")
        print("load " + str(loaded * 100 / total) + " %")

        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and would let a partially initialized model through silently.
        if loaded != total:
            raise RuntimeError(
                f"model loading error: {total - loaded} of {total} state entries "
                f"missing from {cfg.weight_path}"
            )

        # Freeze every parameter reachable through the model's direct children.
        for child in model.children():
            for param in child.parameters():
                param.requires_grad = False

        print('model_load')
        if torch.cuda.is_available():
            model.cuda(cfg.local_rank)
            # NOTE(review): newer PyTorch releases make DistributedDataParallel
            # raise a RuntimeError when no parameter requires grad. That is the
            # case here after the freezing loop unless the model owns
            # parameters directly. Confirm against the PyTorch version in use.
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[cfg.local_rank], output_device=cfg.local_rank
            )

        self._state = TrainerState(model=model)
        checkpoint_fn = osp.join(cfg.save_folder, str(cfg.job_id), "checkpoint.pth")
        if os.path.isfile(checkpoint_fn):
            print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
            self._state = TrainerState.load(checkpoint_fn, default=self._state)