ASLRecognition/scripts/train.py (124 lines of code) (raw):

'''
USAGE:
python train.py --epochs 10

Train `cnn_models.CustomCNN` on the images listed in `data.csv`
(columns `image_path` and `target`) and save the final weights to `asl.pth`.
'''
import pandas as pd
import joblib
import numpy as np
import torch
import random
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import time
import cnn_models
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# construct the argument parser and parse the arguments
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', default=10, type=int,
                    help='number of epochs to train the model for')
args = vars(parser.parse_args())


def seed_everything(SEED=42):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs.

    cuDNN is also switched to deterministic mode and benchmark mode is
    disabled: benchmark mode selects convolution kernels by timing them,
    which can vary between runs and defeats the point of seeding.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 42
seed_everything(SEED=SEED)

# set computation device
device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Computation device: {device}")

# read the data.csv file and get the image paths and labels
df = pd.read_csv('data.csv')
X = df.image_path.values
y = df.target.values

# hold out 15% of the images for validation; the fixed random_state keeps
# the split identical across runs
xtrain, xtest, ytrain, ytest = train_test_split(
    X, y, test_size=0.15, random_state=SEED)
print(f"Training on {len(xtrain)} images")
print(f"Validating on {len(xtest)} images")


class ASLImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` tensor pairs.

    Args:
        path: sequence of image file paths.
        labels: sequence of class indices aligned with ``path``
            (returned as ``torch.long``).
    """

    def __init__(self, path, labels):
        self.X = path
        self.y = labels
        # apply augmentations (currently only a fixed resize to 224x224)
        self.aug = transforms.Compose([
            transforms.Resize((224, 224))
        ])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        # Close the file handle promptly, and force 3 channels so the
        # (H, W, C) -> (C, H, W) transpose below is valid even for
        # grayscale, palette or RGBA files.
        with Image.open(self.X[i]) as img:
            image = self.aug(img.convert('RGB'))
        # Resize returns a PIL image: convert it to an explicit array rather
        # than relying on np.transpose's implicit fallback for objects that
        # define their own .transpose() method.
        # NOTE: pixels stay in [0, 255]; no normalization is applied here.
        image = np.transpose(np.asarray(image), (2, 0, 1)).astype(np.float32)
        label = self.y[i]
        return torch.tensor(image, dtype=torch.float), torch.tensor(label, dtype=torch.long)


train_data = ASLImageDataset(xtrain, ytrain)
test_data = ASLImageDataset(xtest, ytest)

# dataloaders
trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
testloader = DataLoader(test_data, batch_size=32, shuffle=False)

model = cnn_models.CustomCNN().to(device)
print(model)

# total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# loss function (default reduction='mean': returns the batch-average loss)
criterion = nn.CrossEntropyLoss()


def fit(model, dataloader):
    """Train ``model`` for one epoch over ``dataloader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``.

    Returns:
        (train_loss, train_accuracy): mean per-sample loss and accuracy
        (in percent) over the whole dataset.
    """
    print('Training')
    model.train()
    running_loss = 0.0
    running_correct = 0
    # len(dataloader) is the exact batch count, including a final partial batch
    for data, target in tqdm(dataloader, total=len(dataloader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        # criterion returns the batch *mean*; weight it by the batch size so
        # that dividing the epoch total by the dataset size yields a true
        # per-sample mean (the last batch may be smaller).
        running_loss += loss.item() * target.size(0)
        preds = outputs.argmax(dim=1)
        running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()

    train_loss = running_loss / len(dataloader.dataset)
    train_accuracy = 100. * running_correct / len(dataloader.dataset)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}")
    return train_loss, train_accuracy


def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Uses the module-level ``criterion`` and ``device``.

    Returns:
        (val_loss, val_accuracy): mean per-sample loss and accuracy
        (in percent) over the whole dataset.
    """
    print('Validating')
    model.eval()
    running_loss = 0.0
    running_correct = 0
    with torch.no_grad():
        for data, target in tqdm(dataloader, total=len(dataloader)):
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            # weight the batch-mean loss by batch size (see fit)
            running_loss += loss.item() * target.size(0)
            preds = outputs.argmax(dim=1)
            running_correct += (preds == target).sum().item()

    val_loss = running_loss / len(dataloader.dataset)
    val_accuracy = 100. * running_correct / len(dataloader.dataset)
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}')
    return val_loss, val_accuracy


train_loss, train_accuracy = [], []
val_loss, val_accuracy = [], []
start = time.time()
for epoch in range(args['epochs']):
    print(f"Epoch {epoch+1} of {args['epochs']}")
    train_epoch_loss, train_epoch_accuracy = fit(model, trainloader)
    val_epoch_loss, val_epoch_accuracy = validate(model, testloader)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    print(f"loss: {val_epoch_loss}, accuracy: {val_epoch_accuracy}")
end = time.time()
print(f"Training time: {(end - start) / 60:.3f} minutes")

print('Saving model...')
torch.save(model.state_dict(), 'asl.pth')