in ASLRecognition/scripts/train.py [0:0]
def fit(model, dataloader):
    """Run one training epoch of ``model`` over ``dataloader``.

    Uses the module-level ``device``, ``optimizer``, ``criterion`` and
    ``tqdm`` defined elsewhere in this script.

    For each ``(inputs, targets)`` batch the tensors are moved to ``device``,
    a forward pass is made, ``criterion`` is evaluated, gradients are
    back-propagated and ``optimizer`` takes one step. Loss and accuracy are
    accumulated from the outputs of that forward pass.

    Args:
        model: the module being trained; it is switched to train mode.
        dataloader: iterable of ``(inputs, targets)`` batches. When it
            supports ``len()`` (e.g. a ``DataLoader``), the progress bar
            shows the real number of batches.

    Returns:
        ``(train_loss, train_accuracy)``: the sample-weighted mean loss over
        the samples seen this epoch, and the percentage (0-100) of those
        samples whose arg-max prediction matched the target. Both are 0.0
        if the loader yields no samples.
    """
    print('Training')
    model.train()
    running_loss = 0.0
    running_correct = 0
    num_samples = 0
    # tqdm takes the total from len(dataloader) when available. That is the
    # true batch count, so partial last batches and drop_last are handled.
    for inputs, targets in tqdm(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # NOTE(review): assumes criterion returns the batch *mean* loss
        # (PyTorch's default reduction='mean'); re-weighting by batch size
        # turns it back into a per-sample sum so uneven batches count
        # correctly. TODO confirm against the criterion definition.
        running_loss += loss.item() * batch_size
        preds = outputs.detach().argmax(dim=1)
        running_correct += (preds == targets).sum().item()
        num_samples += batch_size

    # Divide by samples actually seen (not len(dataset)) so drop_last is
    # handled; max(..., 1) avoids ZeroDivisionError on an empty loader.
    denom = max(num_samples, 1)
    train_loss = running_loss / denom
    train_accuracy = 100. * running_correct / denom
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}")
    return train_loss, train_accuracy