ASLRecognition/scripts/train.py
'''
USAGE:
python train.py --epochs 10
'''
import pandas as pd
import joblib
import numpy as np
import torch
import random
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import time
import cnn_models
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# construct the argument parser and parse the arguments
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', default=10, type=int,
help='number of epochs to train the model for')
args = vars(parser.parse_args())
''' SEED everything for reproducibility '''
def seed_everything(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # make cuDNN deterministic so repeated runs with the same seed match
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
SEED = 42
seed_everything(SEED=SEED)
# set computation device
device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Computation device: {device}")
# read the data.csv file and get the image paths and labels;
# data.csv is expected to provide an 'image_path' column and an integer 'target' column
df = pd.read_csv('data.csv')
X = df.image_path.values
y = df.target.values
xtrain, xtest, ytrain, ytest = train_test_split(X, y,
    test_size=0.15, random_state=42)
print(f"Training on {len(xtrain)} images")
print(f"Validating on {len(xtest)} images")
# image dataset module
class ASLImageDataset(Dataset):
    def __init__(self, path, labels):
        self.X = path
        self.y = labels
        # resize every image to the fixed input size the CNN expects
        self.aug = transforms.Compose([
            transforms.Resize((224, 224))
        ])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        image = Image.open(self.X[i])
        image = self.aug(image)
        # HWC -> CHW, as expected by PyTorch convolutional layers
        image = np.transpose(np.array(image), (2, 0, 1)).astype(np.float32)
        label = self.y[i]
        return torch.tensor(image, dtype=torch.float), torch.tensor(label, dtype=torch.long)
train_data = ASLImageDataset(xtrain, ytrain)
test_data = ASLImageDataset(xtest, ytest)
# dataloaders
trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
testloader = DataLoader(test_data, batch_size=32, shuffle=False)
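# note: num_workers and pin_memory are left at their DataLoader defaults here;
# raising them can speed up loading on a GPU, but the best values are machine-specific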
# model = models.MobileNetV2(pretrained=True, requires_grad=False)  # alternative backbone (torchvision models not imported here)
model = cnn_models.CustomCNN().to(device)  # CustomCNN is defined in the local cnn_models module
print(model)
# total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
# optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# loss function
criterion = nn.CrossEntropyLoss()
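# CrossEntropyLoss expects raw (unnormalised) logits and integer class indices as
# targets, which matches the torch.long labels returned by ASLImageDataset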
# training function
def fit(model, dataloader):
    print('Training')
    model.train()
    running_loss = 0.0
    running_correct = 0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        data, target = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()
    # loss.item() is already a per-batch mean, so average over batches;
    # accuracy is counted per sample, so divide by the dataset size
    train_loss = running_loss/len(dataloader)
    train_accuracy = 100. * running_correct/len(dataloader.dataset)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}")
    return train_loss, train_accuracy
# validation function
def validate(model, dataloader):
    print('Validating')
    model.eval()
    running_loss = 0.0
    running_correct = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(test_data)/dataloader.batch_size)):
            data, target = data[0].to(device), data[1].to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            running_correct += (preds == target).sum().item()
    val_loss = running_loss/len(dataloader)
    val_accuracy = 100. * running_correct/len(dataloader.dataset)
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}')
    return val_loss, val_accuracy
train_loss, train_accuracy = [], []
val_loss, val_accuracy = [], []
start = time.time()
for epoch in range(args['epochs']):
    print(f"Epoch {epoch+1} of {args['epochs']}")
    train_epoch_loss, train_epoch_accuracy = fit(model, trainloader)
    val_epoch_loss, val_epoch_accuracy = validate(model, testloader)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    print(f"loss: {val_epoch_loss}, accuracy: {val_epoch_accuracy}")
end = time.time()
print(f"Training took {(end - start)/60:.3f} minutes")
print('Saving model...')
torch.save(model.state_dict(), 'asl.pth')
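# matplotlib is imported above but never used; the sketch below is an optional,
# assumed addition that plots the accuracy and loss curves tracked during training
# and saves them next to the script (the output file names are illustrative).
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('accuracy.png')
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss.png')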