# _train() — excerpt from imnet_finetune/train.py

    def _set_finetune_mode(self, train: bool) -> None:
        """Switch only the fine-tuned submodules between train and eval mode.

        The backbone stays frozen (caller puts the whole model in eval mode
        first); which submodules are fine-tuned depends on the configured
        architecture.

        Args:
            train: True for ``.train()`` mode, False for ``.eval()`` mode.
        """
        module = self._state.model.module
        if self._train_cfg.architecture == 'PNASNet':
            parts = [module.cell_11, module.cell_10, module.cell_9, module.dropout]
        elif self._train_cfg.architecture == 'EfficientNet':
            parts = [module.classifier, module.conv_head, module.bn2]
        else:
            # Default (presumably a ResNet-style model): only the last BN layer.
            parts = [module.layer4[2].bn3]
        for part in parts:
            # nn.Module.train(False) is equivalent to .eval()
            part.train(mode=train)

    def _evaluate(self, criterion) -> float:
        """Evaluate the model on the test loader and return its accuracy.

        Puts the model (including fine-tuned submodules) in eval mode, runs
        one pass over ``self._test_loader`` without gradients, and prints
        accuracy and mean per-batch loss.

        Args:
            criterion: loss function applied to (outputs, labels) per batch.

        Returns:
            Top-1 accuracy as a float in [0, 1].
        """
        correct = 0
        total = 0
        batches = 0
        running_val_loss = 0.0
        self._state.model.eval()
        self._set_finetune_mode(train=False)
        with torch.no_grad():
            for images, labels in self._test_loader:
                images = images.cuda(self._train_cfg.local_rank, non_blocking=True)
                labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True)
                outputs = self._state.model(images)
                running_val_loss += criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                batches += 1
        acc = correct / total
        mean_loss = running_val_loss / batches
        print(f"Accuracy of the network on the 50000 test images: {acc:.1%}", flush=True)
        print(f"Loss of the network on the 50000 test images: {mean_loss:.3f}", flush=True)
        return acc

    def _train(self) -> Optional[float]:
        """Fine-tune the architecture-specific submodules of the model.

        Evaluates once before training, then trains from the loaded epoch to
        ``self._train_cfg.epochs``; on the last epoch it re-evaluates, stores
        the accuracy in ``self._state.accuracy``, checkpoints on global rank 0,
        and returns the final accuracy.

        Returns:
            Final test accuracy, or None when no epoch reaches the final one
            (i.e. the loaded ``start_epoch`` is already past the configured
            number of epochs).
        """
        criterion = nn.CrossEntropyLoss()
        print_freq = 10

        print("Evaluation before fine-tuning")
        acc = self._evaluate(criterion)
        print("Accuracy before fine-tuning : " + str(acc))

        # Start from the loaded epoch
        start_epoch = self._state.epoch
        for epoch in range(start_epoch, self._train_cfg.epochs):
            print(f"Start epoch {epoch}", flush=True)
            # Freeze the backbone (eval mode) and enable training only on the
            # submodules selected for fine-tuning.
            self._state.model.eval()
            self._set_finetune_mode(train=True)

            self._state.lr_scheduler.step(epoch)
            self._state.epoch = epoch
            running_loss = 0.0
            for i, data in enumerate(self._train_loader):
                inputs, labels = data
                inputs = inputs.cuda(self._train_cfg.local_rank, non_blocking=True)
                labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True)

                outputs = self._state.model(inputs)
                loss = criterion(outputs, labels)

                self._state.optimizer.zero_grad()
                loss.backward()
                self._state.optimizer.step()

                running_loss += loss.item()
                if i % print_freq == print_freq - 1:
                    print(f"[{epoch:02d}, {i:05d}] loss: {running_loss/print_freq:.3f}", flush=True)
                    running_loss = 0.0

            if epoch == self._train_cfg.epochs - 1:
                print("Start evaluation of the model", flush=True)
                acc = self._evaluate(criterion)
                self._state.accuracy = acc
                # Only the rank-0 process writes the checkpoint.
                if self._train_cfg.global_rank == 0:
                    self.checkpoint(rm_init=False)
                return acc
        # Loop never reached the final epoch: preserve the original implicit
        # None return.
        return None