def main()

in geospatial/preprocessing/create_label_masks.py [0:0]


def main():
    """Command-line entry point: create target masks for the selected disasters.

    Parses the command line, checks that ``root_dir`` contains ``images`` and
    ``labels`` sub-directories, and creates ``<root_dir>/targets_border<N>``
    (``N`` = ``--border_width``) if it does not already exist. It then collects
    the ``.json`` files in ``labels`` whose names start with one of the
    prefixes in the module-level ``DISASTERS_OF_INTEREST``. Those label paths
    are passed, in sorted order, to ``mask_tiles`` together with the images
    directory, the targets directory, the border width and the overwrite flag.

    Invalid command-line arguments are reported through ``parser.error``
    (usage message, exit status 2) instead of ``assert``, so the checks still
    run when Python is started with ``-O``.

    Raises:
        TypeError: if ``DISASTERS_OF_INTEREST`` is not a tuple.
        ValueError: if an entry of ``DISASTERS_OF_INTEREST`` does not end with ``_``.
    """
    parser = argparse.ArgumentParser(
        description='Create masks for each label json file for disasters specified at the top of the script.')
    parser.add_argument(
        'root_dir',
        help=('Path to the directory that contains both the `images` and `labels` folders. '
              'The `targets_border{border_width}` folder will be created if it does not already exist.')
    )
    parser.add_argument(
        '-b', '--border_width',
        type=int,
        default=1,
        help='border width passed to the mask generation; must be between 0 and 4 (default: 1)'
    )
    parser.add_argument(
        '-o', '--overwrite_target',
        help='flag if we want to generate all targets anew',
        action='store_true'
    )
    args = parser.parse_args()

    images_dir = os.path.join(args.root_dir, 'images')
    labels_dir = os.path.join(args.root_dir, 'labels')

    # Explicit checks rather than `assert`: asserts are removed under `python -O`.
    if not os.path.exists(args.root_dir):
        parser.error('root_dir does not exist')
    if not os.path.isdir(args.root_dir):
        parser.error('root_dir needs to be path to a directory')
    if not os.path.isdir(images_dir):
        parser.error('root_dir does not contain the folder `images`')
    if not os.path.isdir(labels_dir):
        parser.error('root_dir does not contain the folder `labels`')
    if args.border_width < 0:
        parser.error('border_width < 0')
    if args.border_width > 4:
        parser.error('specified border_width is > 4 pixels - are you sure?')

    # `str.startswith` (used below) only accepts a str or a tuple of str, so a
    # list here would otherwise fail later with a less helpful error.
    if not isinstance(DISASTERS_OF_INTEREST, tuple):
        raise TypeError('DISASTERS_OF_INTEREST must be a tuple of file-name prefixes')
    # Presumably the trailing '_' keeps one disaster's prefix from also matching
    # another disaster whose name merely starts with the same text — confirm
    # against the label file naming scheme.
    bad_prefixes = [prefix for prefix in DISASTERS_OF_INTEREST if not prefix.endswith('_')]
    if bad_prefixes:
        raise ValueError(f'DISASTERS_OF_INTEREST entries must end with "_": {bad_prefixes}')

    print(f'Disasters to create the masks for: {DISASTERS_OF_INTEREST}')

    targets_dir = os.path.join(args.root_dir, f'targets_border{args.border_width}')
    print(f'A targets directory is at {targets_dir}')
    os.makedirs(targets_dir, exist_ok=True)

    # list out label files for the disasters of interest (sorted for a stable order)
    li_label_fn = sorted(fn for fn in os.listdir(labels_dir) if fn.endswith('.json'))
    li_label_paths = [os.path.join(labels_dir, fn)
                      for fn in li_label_fn if fn.startswith(DISASTERS_OF_INTEREST)]

    print(f'{len(li_label_fn)} label jsons found in labels_dir, '
          f'{len(li_label_paths)} are for the disasters of interest.')

    mask_tiles(images_dir, li_label_paths, targets_dir, args.border_width, args.overwrite_target)
    print('Done!')