in community-content/pytorch_image_classification_single_gpu_with_vertex_sdk_and_torchserve/trainer/task.py [0:0]
def _to_gcsfuse_path(path):
    """Map a ``gs://bucket/...`` URI onto the ``/gcs/bucket/...`` FUSE path.

    Only the leading ``gs://`` scheme is rewritten, so a later occurrence of
    the substring inside the path is left alone. Paths that do not start with
    ``gs://`` are returned unchanged.
    """
    gs_prefix = 'gs://'
    gcsfuse_prefix = '/gcs/'
    if path.startswith(gs_prefix):
        return gcsfuse_prefix + path[len(gs_prefix):]
    return path


def main():
    """Fine-tune the image classifier and save its weights.

    Steps:
      1. Parse CLI arguments.
      2. Resolve the model and TensorBoard directories. Each falls back to
         a ``./tmp/...`` directory when unset or empty. ``gs://`` URIs are
         rewritten to the ``/gcs/`` Cloud Storage FUSE path so they can be
         written with ordinary file APIs.
      3. Download the data and build train/val dataloaders.
      4. Train with SGD and a StepLR schedule (LR x0.1 every 7 epochs).
      5. Save the resulting ``state_dict`` to ``<model_dir>/antandbee.pth``.

    The TensorBoard writer is always closed, even when a step raises.
    """
    args = parse_args()
    local_data_dir = './tmp/data'
    local_model_dir = './tmp/model'
    local_tensorboard_log_dir = './tmp/logs'

    # `or` also falls back when the argument is an empty string, so both
    # directories are always non-empty here.
    model_dir = _to_gcsfuse_path(args.model_dir or local_model_dir)
    tensorboard_log_dir = _to_gcsfuse_path(
        args.tensorboard_log_dir or local_tensorboard_log_dir)

    makedirs(model_dir)
    writer = SummaryWriter(tensorboard_log_dir)
    try:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f'Device: {device}')

        data_dir = download_data(local_data_dir)
        image_datasets, class_names = load_dataset(data_dir)
        dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
        print(f'Dataset sizes: {dataset_sizes}')

        dataloaders = {
            x: torch.utils.data.DataLoader(
                image_datasets[x],
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers,
            )
            for x in ['train', 'val']
        }

        model_ft = load_model(class_names, device)
        criterion = torch.nn.CrossEntropyLoss()
        # Observe that all parameters are being optimized
        optimizer_ft = optim.SGD(
            model_ft.parameters(), lr=args.learning_rate, momentum=args.momentum)
        # Decay LR by a factor of 0.1 every 7 epochs
        exp_lr_scheduler = optim.lr_scheduler.StepLR(
            optimizer_ft, step_size=7, gamma=0.1)

        model = train(
            model=model_ft,
            criterion=criterion,
            optimizer=optimizer_ft,
            scheduler=exp_lr_scheduler,
            dataset_sizes=dataset_sizes,
            dataloaders=dataloaders,
            device=device,
            epochs=args.epochs,
            writer=writer,
        )

        model_name = 'antandbee.pth'
        model_path = os.path.join(model_dir, model_name)
        torch.save(model.state_dict(), model_path)
        print(f'Model is saved to {model_dir}')
        print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')
    finally:
        # Flush and release the event file even if training or saving fails.
        writer.close()