train-cnn/Container/scripts/multi-train.py (76 lines of code) (raw):

import torch import torchvision import os ''' This model is based on the Cats & Dogs classifier from TirendazAcademy on Kaggle: https://www.kaggle.com/code/tirendazacademy/cats-dogs-classification-with-pytorch?scriptVersionId=127290261&cellId=41 slightly modified in the first layer owing to our preprocessing the training images to use only a single color channel instead of three. ''' class ImageClassifier(torch.nn.Module): def __init__(self): super().__init__() self.conv_layer_1 = torch.nn.Sequential( torch.nn.Conv2d(1, 64, 3, padding=1), torch.nn.ReLU(), torch.nn.BatchNorm2d(64), torch.nn.MaxPool2d(2)) self.conv_layer_2 = torch.nn.Sequential( torch.nn.Conv2d(64, 512, 3, padding=1), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.MaxPool2d(2)) self.conv_layer_3 = torch.nn.Sequential( torch.nn.Conv2d(512, 512, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.MaxPool2d(2)) self.classifier = torch.nn.Sequential( torch.nn.Flatten(), torch.nn.Linear(in_features=2048, out_features=2)) def forward(self, x: torch.Tensor): x = self.conv_layer_1(x) x = self.conv_layer_2(x) for _ in range(4): x = self.conv_layer_3(x) x = self.classifier(x) return x if __name__ == "__main__": global_rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group(backend="nccl") torch.cuda.set_device(local_rank) train_set = torchvision.datasets.ImageFolder(root="/tmp/train", transform=torchvision.transforms.Compose([torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor()]), target_transform=None) model = ImageClassifier() model.to(local_rank) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) train_data = torch.utils.data.DataLoader( train_set, batch_size=32, pin_memory=True, shuffle=False, sampler=torch.utils.data.distributed.DistributedSampler(train_set) ) for epoch in range(50): total_loss = 0.0 num_batches = 0 batch_size = len(next(iter(train_data))[0]) train_data.sampler.set_epoch(epoch) for source, targets in train_data: source = source.to(local_rank) targets = targets.to(local_rank) optimizer.zero_grad() output = model(source) loss = torch.nn.functional.cross_entropy(output, targets) torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) loss /= torch.distributed.get_world_size() total_loss += loss.item() num_batches += 1 loss.backward() optimizer.step() average_loss = total_loss / num_batches print(f"[Node{global_rank}] Epoch {epoch} | Batchsize: {batch_size} | Steps: {len(train_data)} | Average Loss: {average_loss:.4f}") torch.distributed.destroy_process_group() if global_rank == 0: torch.save(model.state_dict(), os.path.join(os.environ["DATA_ROOT"], "final.pt"))