def _train()

in imnet_resnet50_scratch/train.py [0:0]


    def _train(self) -> Optional[float]:
        """Train from the loaded epoch to the configured last epoch, then evaluate.

        Training resumes at ``self._state.epoch`` and runs through epoch
        ``self._train_cfg.epochs - 1``. Each epoch stops early once a fixed
        iteration cap (``max_iters`` below) is reached. After the final epoch,
        the model is evaluated on ``self._test_loader``:

        - the top-1 accuracy is stored in ``self._state.accuracy``;
        - global rank 0 writes a checkpoint;
        - the accuracy is returned.

        Returns:
            Top-1 accuracy (fraction of correct predictions) on the test
            loader after the last epoch. Returns ``None`` if the last epoch
            is never run, e.g. when the loaded epoch is already ``>= epochs``.

        Raises:
            RuntimeError: if the test loader yields no samples, so that
                accuracy cannot be computed.
        """
        criterion = nn.CrossEntropyLoss()
        print_freq = 10
        local_rank = self._train_cfg.local_rank
        last_epoch = self._train_cfg.epochs - 1
        # Per-epoch iteration cap, independent of the train loader's length.
        # It is loop-invariant, so compute it once.
        # NOTE(review): 5005 * 512 presumably targets a fixed number of samples
        # per epoch summed over all workers -- confirm the intended constant.
        max_iters = 5005 * 512 / (self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks)

        # Start from the loaded epoch (resume support).
        for epoch in range(self._state.epoch, self._train_cfg.epochs):
            print(f"Start epoch {epoch}", flush=True)
            self._state.model.train()
            # NOTE(review): passing an explicit epoch to scheduler.step() is
            # deprecated in recent PyTorch releases; kept to preserve the schedule.
            self._state.lr_scheduler.step(epoch)
            self._state.epoch = epoch
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(self._train_loader):
                inputs = inputs.cuda(local_rank, non_blocking=True)
                labels = labels.cuda(local_rank, non_blocking=True)

                outputs = self._state.model(inputs)
                loss = criterion(outputs, labels)

                self._state.optimizer.zero_grad()
                loss.backward()
                self._state.optimizer.step()

                # Report the mean loss over the last `print_freq` batches.
                running_loss += loss.item()
                if i % print_freq == print_freq - 1:
                    print(f"[{epoch:02d}, {i:05d}] loss: {running_loss / print_freq:.3f}", flush=True)
                    running_loss = 0.0
                # i + 1 iterations have completed at this point.
                if i + 1 >= max_iters:
                    break

            # Evaluation only happens after the final epoch.
            if epoch != last_epoch:
                continue

            print("Start evaluation of the model", flush=True)
            correct = 0
            total = 0
            num_batches = 0
            running_val_loss = 0.0
            self._state.model.eval()
            with torch.no_grad():
                for images, labels in self._test_loader:
                    images = images.cuda(local_rank, non_blocking=True)
                    labels = labels.cuda(local_rank, non_blocking=True)
                    outputs = self._state.model(images)
                    running_val_loss += criterion(outputs, labels).item()
                    # Index of the highest logit along dim 1 is the predicted class.
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    num_batches += 1

            if total == 0:
                raise RuntimeError("Test loader yielded no samples; cannot compute accuracy")

            acc = correct / total
            # Mean of per-batch losses (not weighted by batch size).
            mean_val_loss = running_val_loss / num_batches
            print(f"Accuracy of the network on the {total} test images: {acc:.1%}", flush=True)
            print(f"Loss of the network on the {total} test images: {mean_val_loss:.3f}", flush=True)
            self._state.accuracy = acc
            if self._train_cfg.global_rank == 0:
                self.checkpoint(rm_init=False)
            print(f"accuracy val epoch {epoch} acc= {acc}", flush=True)
            return acc

        return None