def merge()

in misc/merge_iNat_and_animals_extension.py [0:0]


def _assign_new_category_ids(mapping, old_ids, next_id):
    """Add an entry to *mapping* for every id in *old_ids* it does not contain yet.

    New ids are handed out consecutively starting at *next_id*, visiting the
    old ids in ascending order so the numbering is reproducible.

    Returns:
        The next unused new id.
    """
    for old_id in sorted(set(old_ids)):
        if old_id not in mapping:
            mapping[old_id] = next_id
            next_id += 1
    return next_id


def _apply_category_mapping(js, mapping):
    """Rewrite the category ids of *js* in place through *mapping*.

    Every annotated category is expected to be in *mapping* already. Entries
    of ``js['categories']`` without a mapping therefore have no annotations
    and are dropped instead of raising ``KeyError``.
    """
    js['categories'] = [cat for cat in js['categories'] if cat['id'] in mapping]
    for cat in js['categories']:
        cat['id'] = mapping[cat['id']]
    for ann in js['annotations']:
        ann['category_id'] = mapping[ann['category_id']]


def merge(input1, input2, output, old_new_mapping=None,
          input1_class_blacklist=(1221, 375)):
    """Merge the iNat json (input1) with the extension-only json (input2).

    Both inputs are json files holding 'images', 'annotations' and
    'categories' lists. Images of input1 that carry at least one annotation
    from ``input1_class_blacklist`` are dropped together with all of their
    annotations, and the blacklisted categories are removed. The category ids
    of both inputs are then renumbered into one shared, collision-free id
    space (input1 first, then input2). The concatenated result is written to
    *output*.

    Args:
        input1: Path to the iNat json.
        input2: Path to the extension-only json.
        output: Path the merged json is written to (overwritten).
        old_new_mapping: Optional ``[map1, map2]`` pair of dicts from old to
            new category id, for input1 and input2 respectively. Existing
            entries are reused and new entries are added in place. The mapping
            returned by one call can therefore be passed to a later call to
            keep ids consistent. A fresh pair is created when omitted.
        input1_class_blacklist: input1 category ids that duplicate input2
            classes and are therefore removed from input1.

    Returns:
        The (possibly extended) ``old_new_mapping`` pair.
    """
    if old_new_mapping is None:
        # A literal default would be a single list shared and mutated by every call.
        old_new_mapping = [{}, {}]
    map1 = old_new_mapping[0]
    map2 = old_new_mapping[1]
    print('mapping', len(map1), len(map2))
    # Presumed input2 -> input1 duplicates behind the default blacklist:
    # {659: 1221, 216: 375}
    blacklist = set(input1_class_blacklist)

    with open(input1, 'rt', encoding='utf-8') as fi:
        js1 = json.load(fi)
    with open(input2, 'rt', encoding='utf-8') as fi:
        js2 = json.load(fi)

    # Delete duplicate classes from input1: drop every image that has a
    # blacklisted annotation, together with all annotations of those images.
    images_to_delete = {ann['image_id'] for ann in js1['annotations']
                        if ann['category_id'] in blacklist}
    for k, v in js1.items():
        print(k, len(v))
    print('Going to delete {} images due to duplicate classes'.format(len(images_to_delete)))
    js1['images'] = [im for im in js1['images'] if im['id'] not in images_to_delete]
    js1['categories'] = [cat for cat in js1['categories'] if cat['id'] not in blacklist]
    js1['annotations'] = [ann for ann in js1['annotations']
                          if ann['image_id'] not in images_to_delete]
    for k, v in js1.items():
        print(k, len(v))

    # Renumber classes. New ids start after every id already present in
    # either mapping, so they never collide, even when a mapping from an
    # earlier call is reused.
    next_id = max([*map1.values(), *map2.values()], default=-1) + 1
    next_id = _assign_new_category_ids(
        map1, (ann['category_id'] for ann in js1['annotations']), next_id)
    _assign_new_category_ids(
        map2, (ann['category_id'] for ann in js2['annotations']), next_id)
    _apply_category_mapping(js1, map1)
    _apply_category_mapping(js2, map2)

    # NOTE(review): image and annotation ids are concatenated unchanged, which
    # assumes input1 and input2 use disjoint ids -- confirm.
    js1['images'] += js2['images']
    js1['annotations'] += js2['annotations']
    js1['categories'] += js2['categories']

    with open(output, 'wt', encoding='utf-8') as fo:
        json.dump(js1, fo)
    print('mapping', len(map1), len(map2))
    return old_new_mapping