in imnet_resnet50_scratch/train.py [0:0]
def _train(self) -> Optional[float]:
    """Train the model and evaluate it once after the final epoch.

    Training resumes at ``self._state.epoch`` and runs up to (but not
    including) ``self._train_cfg.epochs``. Each epoch consumes batches from
    ``self._train_loader`` until the loader is exhausted or the per-epoch
    step cap is reached. After the last epoch the model is evaluated on
    ``self._test_loader``, the accuracy is stored in
    ``self._state.accuracy``, and the process with ``global_rank == 0``
    calls ``self.checkpoint(rm_init=False)``.

    Returns:
        The fraction of correctly classified test samples measured after
        the final epoch, or ``None`` when no epoch is run because the
        stored epoch is already >= the configured number of epochs.

    Raises:
        ZeroDivisionError: If the test loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    print_freq = 10
    device = self._train_cfg.local_rank
    last_epoch = self._train_cfg.epochs - 1

    # Cap on optimizer steps per epoch, computed once instead of per batch.
    # NOTE(review): 5005 * 512 looks like a fixed per-epoch sample budget
    # divided by the global batch size (batch_per_gpu * num_tasks) -- confirm.
    max_steps_per_epoch = 5005 * 512 / (
        self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks
    )

    # Start from the loaded epoch
    for epoch in range(self._state.epoch, self._train_cfg.epochs):
        print(f"Start epoch {epoch}", flush=True)
        self._state.model.train()
        # NOTE(review): passing the epoch to step() is deprecated in recent
        # PyTorch releases; kept unchanged so the schedule is not altered.
        self._state.lr_scheduler.step(epoch)
        self._state.epoch = epoch

        running_loss = 0.0
        for i, (inputs, labels) in enumerate(self._train_loader):
            inputs = inputs.cuda(device, non_blocking=True)
            labels = labels.cuda(device, non_blocking=True)

            outputs = self._state.model(inputs)
            loss = criterion(outputs, labels)

            self._state.optimizer.zero_grad()
            loss.backward()
            self._state.optimizer.step()

            running_loss += loss.item()
            if i % print_freq == print_freq - 1:
                # Report the mean loss over the last `print_freq` steps.
                print(
                    f"[{epoch:02d}, {i:05d}] loss: {running_loss/print_freq:.3f}",
                    flush=True,
                )
                running_loss = 0.0
            # i + 1 is the number of steps completed in this epoch.
            if i + 1 >= max_steps_per_epoch:
                break

        # Evaluation only happens once, after the last training epoch.
        if epoch != last_epoch:
            continue

        print("Start evaluation of the model", flush=True)
        correct = 0
        total = 0
        num_batches = 0
        running_val_loss = 0.0
        self._state.model.eval()
        with torch.no_grad():
            for images, labels in self._test_loader:
                images = images.cuda(device, non_blocking=True)
                labels = labels.cuda(device, non_blocking=True)
                outputs = self._state.model(images)
                loss_val = criterion(outputs, labels)
                # Index of the largest score along dim 1 is the prediction.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                running_val_loss += loss_val.item()
                num_batches += 1

        acc = correct / total
        # Mean of the per-batch losses (not weighted by batch size).
        mean_val_loss = running_val_loss / num_batches
        print(f"Accuracy of the network on the 50000 test images: {acc:.1%}", flush=True)
        print(f"Loss of the network on the 50000 test images: {mean_val_loss:.3f}", flush=True)
        self._state.accuracy = acc
        if self._train_cfg.global_rank == 0:
            self.checkpoint(rm_init=False)
        print(f"accuracy val epoch {epoch} acc= {acc}", flush=True)
        return acc

    # No epoch was run: training had already reached the configured length.
    return None