in classification/train_classifier.py [0:0]
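# Module-level imports implied by this function (not shown in this excerpt).
# The standard-library / third-party imports below follow directly from the
# code; the project-local helpers (create_dataloaders, build_model,
# prep_device, set_finetune, run_epoch, log_run, prefix_all_keys,
# efficientnet, evaluate_model) are imported elsewhere in the package, so
# their exact import paths are not reproduced here.
import json
import os
from datetime import datetime
from typing import Optional

import numpy as np
import torch
from torch.utils import tensorboard
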
def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool | str,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: bool | str,
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         num_workers: int,
         logdir: str,
         log_extreme_examples: int,
         seed: Optional[int] = None) -> None:
"""Main function."""
# input validation
assert os.path.exists(dataset_dir)
assert os.path.exists(cropped_images_dir)
if isinstance(weight_by_detection_conf, str):
assert os.path.exists(weight_by_detection_conf)
if isinstance(pretrained, str):
assert os.path.exists(pretrained)
    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # create logdir and save params
    params = dict(locals())  # snapshot of the arguments (seed now resolved)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # e.g., '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    params_json_path = os.path.join(logdir, 'params.json')
    with open(params_json_path, 'w') as f:
        json.dump(params, f, indent=1)
    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224  # standard input resolution for torchvision ResNets
    # create dataloaders and log the index_to_label mapping
    print('Creating dataloaders')
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        num_workers=num_workers,
        augment_train=True)
    writer = tensorboard.SummaryWriter(logdir)
    # create model
    model = build_model(model_name, num_classes=len(label_names),
                        pretrained=pretrained, finetune=finetune > 0)
    model, device = prep_device(model)

    # define loss function and optimizer
    loss_fn: torch.nn.Module
    if multilabel:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)
    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    optimizer: torch.optim.Optimizer
    if 'efficientnet' in model_name:
        optimizer = torch.optim.RMSprop(model.parameters(), lr, alpha=0.9,
                                        momentum=0.9, weight_decay=weight_decay)
        # stepping every epoch with gamma = 0.97 ** (1 / 2.4) compounds to a
        # 0.97 decay over 2.4 epochs, matching the EfficientNet schedule
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97 ** (1 / 2.4))
    else:  # resnet
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9,
                                    weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=8, gamma=0.1)  # decay LR by 10x every 8 epochs
    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            set_finetune(model, model_name, finetune=False)

        print('- train:')
        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=False, device=device,
            loss_fn=loss_fn, finetune=finetune > epoch, optimizer=optimizer,
            k_extreme=log_extreme_examples)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names,
                metrics=train_metrics, heaps=train_heaps, cm=train_cm)
        del train_heaps

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names,
                metrics=val_metrics, heaps=val_heaps, cm=val_cm)
        del val_heaps

        lr_scheduler.step()  # decrease the learning rate
        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.pt')
            print(f'New best model! Saving checkpoint to {filename}')
            state = {
                'epoch': epoch,
                # unwrap nn.DataParallel (its .module attribute) if present
                'model': getattr(model, 'module', model).state_dict(),
                'val/acc': val_metrics['val/acc_top1'],
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            # only evaluate the test split when a new best model is found
            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names,
                    metrics=test_metrics, heaps=test_heaps, cm=test_cm)
            del test_heaps

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break
    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    metric_dict = prefix_all_keys(best_epoch_metrics, prefix='hparam/')
    writer.add_hparams(hparam_dict=hparams_dict, metric_dict=metric_dict)
    writer.close()

    # do a complete evaluation run on the best checkpoint
    best_epoch = best_epoch_metrics['epoch']
    evaluate_model.main(
        params_json_path=params_json_path,
        ckpt_path=os.path.join(logdir, f'ckpt_{best_epoch}.pt'),
        output_dir=logdir, splits=evaluate_model.SPLITS)
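
# Example direct invocation (a minimal sketch; the argument values and paths
# below are hypothetical, and the script's actual command-line interface is
# presumably defined elsewhere in this file):
#
# main(dataset_dir='run/classification_ds',
#      cropped_images_dir='run/crops',
#      multilabel=False,
#      model_name='efficientnet-b3',
#      pretrained=True,
#      finetune=0,
#      label_weighted=True,
#      weight_by_detection_conf=False,
#      epochs=50,
#      batch_size=256,
#      lr=3e-3,
#      weight_decay=1e-5,
#      num_workers=8,
#      logdir='run/logs',
#      log_extreme_examples=5)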