def main()

in classification/train_classifier.py


def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool | str,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: bool | str,
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         num_workers: int,
         logdir: str,
         log_extreme_examples: int,
         seed: Optional[int] = None) -> None:
    """Main function."""
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)
    if isinstance(pretrained, str):
        assert os.path.exists(pretrained)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    params_json_path = os.path.join(logdir, 'params.json')
    with open(params_json_path, 'w') as f:
        json.dump(params, f, indent=1)

    # EfficientNet variants expect a model-specific input resolution
    # (e.g., 224px for efficientnet-b0); other architectures use 224px
    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # create dataloaders and log the index_to_label mapping
    print('Creating dataloaders')
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        num_workers=num_workers,
        augment_train=True)

    writer = tensorboard.SummaryWriter(logdir)

    # create model
    model = build_model(model_name, num_classes=len(label_names),
                        pretrained=pretrained, finetune=finetune > 0)
    model, device = prep_device(model)

    # define loss function and optimizer
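    # reduction='none' keeps per-example losses, presumably so that run_epoch
    # can weight individual examples (see its `weighted` argument) before averaging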
    loss_fn: torch.nn.Module
    if multilabel:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    optimizer: torch.optim.Optimizer
    if 'efficientnet' in model_name:
        optimizer = torch.optim.RMSprop(model.parameters(), lr, alpha=0.9,
                                        momentum=0.9, weight_decay=weight_decay)
        # per-epoch gamma chosen so that the lr decays by 0.97 every 2.4 epochs
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97 ** (1 / 2.4))
    else:  # resnet
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9,
                                    weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=8, gamma=0.1)  # lower every 8 epochs

    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            set_finetune(model, model_name, finetune=False)

        print('- train:')
        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=False, device=device,
            loss_fn=loss_fn, finetune=finetune > epoch, optimizer=optimizer,
            k_extreme=log_extreme_examples)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names,
                metrics=train_metrics, heaps=train_heaps, cm=train_cm)
        del train_heaps

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names,
                metrics=val_metrics, heaps=val_heaps, cm=val_cm)
        del val_heaps

        lr_scheduler.step()  # decrease the learning rate

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.pt')
            print(f'New best model! Saving checkpoint to {filename}')
            state = {
                'epoch': epoch,
                'model': getattr(model, 'module', model).state_dict(),
                'val/acc': val_metrics['val/acc_top1'],
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names,
                    metrics=test_metrics, heaps=test_heaps, cm=test_cm)
            del test_heaps

        # stop training after 8 epochs without improvement; .get() guards
        # against the (unlikely) case that no checkpoint was ever saved
        if epoch >= best_epoch_metrics.get('epoch', 0) + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    metric_dict = prefix_all_keys(best_epoch_metrics, prefix='hparam/')
    writer.add_hparams(hparam_dict=hparams_dict, metric_dict=metric_dict)
    writer.close()

    # do a complete evaluation run
    best_epoch = best_epoch_metrics['epoch']
    evaluate_model.main(
        params_json_path=params_json_path,
        ckpt_path=os.path.join(logdir, f'ckpt_{best_epoch}.pt'),
        output_dir=logdir, splits=evaluate_model.SPLITS)
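
A minimal usage sketch, assuming the module is importable as
classification.train_classifier; the paths and all hyperparameter values below
are illustrative placeholders, not defaults taken from the repo:

    from classification import train_classifier

    train_classifier.main(
        dataset_dir='path/to/dataset_dir',    # contains classification_ds.csv,
                                              # label_index.json, splits.json
        cropped_images_dir='path/to/crops',
        multilabel=False,
        model_name='efficientnet-b0',
        pretrained=True,
        finetune=0,
        label_weighted=True,
        weight_by_detection_conf=False,
        epochs=50,
        batch_size=128,
        lr=3e-3,
        weight_decay=1e-5,
        num_workers=8,
        logdir='path/to/logs',
        log_extreme_examples=5,
        seed=123)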