def run()

in benchmark/supervised/generate_datasets.py [0:0]


def run(config):
    version = config['version']
    for dataset_name, dconf in config['datasets'].items():
        cprint("[%s]\n" % dataset_name, 'yellow')

        # conf
        x_key = dconf['x_key']
        y_key = dconf['y_key']
        splits = dconf['splits']
        train_classes = dconf['train_classes']
        test_classes = dconf['test_classes']
        img_size = dconf['shape'][0]
        index_shots = dconf['index_shots']
        query_shots = dconf['query_shots']

        # download and merge splits
        x = []
        y = []
        for split in splits:
            ds, ds_info = tfds.load(dataset_name, split=split, with_info=True)

            if x_key not in ds_info.features:
                raise ValueError("x_key not found - available features are:",
                                 str(ds_info.features.keys()))
            if y_key not in ds_info.features:
                raise ValueError("y_key not found - available features are:",
                                 str(ds_info.features.keys()))

            pb = tqdm(total=ds_info.splits[split].num_examples,
                      desc="Merging %s" % split)

            for e in ds:
                x.append(e[x_key])
                y.append(int(e[y_key]))
                pb.update()
            pb.close()

    cprint("|-Resize", 'blue')
    x_resized = []
    for e in tqdm(x, desc="resizing"):
        x_resized.append(resize(e, img_size))

    cprint("|-Partition", 'green')
    ds_index = defaultdict(list)
    ds_query = defaultdict(list)
    x_train = []
    y_train = []
    x_test = []
    y_test = []
    train_cls = list(range(train_classes[0], train_classes[1]))
    test_cls = list(range(test_classes[0], test_classes[1]))

    for idx, e in enumerate(x_resized):
        cl = y[idx]
        if len(ds_index[cl]) < index_shots:
            ds_index[cl].append(e)
        elif len(ds_query[cl]) < query_shots:
            ds_query[cl].append(e)
        else:
            if cl in train_cls:
                x_train.append(e)
                y_train.append(cl)
            else:
                x_test.append(e)
                y_test.append(cl)

    # flatten the index
    x_index = []
    y_index = []
    for k, es in ds_index.items():
        y_index.extend([k] * len(es))
        x_index.extend(es)

    # flatten query indexes
    x_unseen_queries = []
    y_unseen_queries = []
    x_seen_queries = []
    y_seen_queries = []
    for k, es in ds_query.items():
        if k in train_cls:
            y_seen_queries.extend([k] * len(es))
            x_seen_queries.extend(es)
        else:
            y_unseen_queries.extend([k] * len(es))
            x_unseen_queries.extend(es)

    # sanity checks
    assert len(y_unseen_queries) == len(x_unseen_queries)
    assert len(y_seen_queries) == len(x_seen_queries)
    assert len(y_index) == len(x_index)
    assert len(x_train) == len(y_train)
    assert len(x_test) == len(y_test)
    for lst in ds_index.values():
        assert len(lst) == index_shots
    for lst in ds_query.values():
        assert len(lst) == query_shots

    print(" |-train", len(x_train), len(y_train))
    print(" |-test", len(x_test), len(y_test))
    print(" |-index", len(x_index), len(y_index))
    print(" |-query seen", len(x_seen_queries), len(y_seen_queries))
    print(" |-query unseen", len(x_unseen_queries), len(y_unseen_queries))

    # save
    fpath = "datasets/%s/%s/" % (version, dataset_name)
    cprint("Save files in %s" % fpath, "blue")
    clean_dir(fpath)

    files = [['train', x_train, y_train],
             ['test', x_test, y_test],
             ['index', x_index, y_index],
             ['unseen_queries', x_unseen_queries, y_unseen_queries],
             ['seen_queries', x_seen_queries, y_seen_queries]]
    for f in files:
        cprint("|-saving %s" % f[0], 'magenta')
        fname = "%s%s.npz" % (fpath, f[0])
        np.savez(fname, x=f[1], y=f[2])
    info = {
        "dataset": dataset_name,
        "splits": splits,
        "img_size": img_size,
        "num_classes": len(train_cls) + len(test_cls),
        "num_train_classes": len(train_cls),
        "num_test_classes": len(test_cls),
        "train_classes": train_cls,
        "test_classes": test_cls,
        "index_shots": index_shots,
        "query_shots": query_shots,
        "data": {
            "train": len(x_train),
            "test": len(x_test),
            "index": len(x_index),
            "unseen_queries": len(y_unseen_queries),
            "seen_queries": len(x_seen_queries)
        }
    }
    with open("%sinfo.json" % fpath, 'w') as o:
        o.write(json.dumps(info))