def main()

in hype/hypernymy_eval.py [0:0]


def main(chkpnt, cpu=False):
    """Run the hypernymy evaluations for a checkpoint and summarize them.

    Args:
        chkpnt: Either a path to a checkpoint file (loaded with ``th.load``)
            or an already-loaded checkpoint object. It is passed to
            ``EntailmentConeModel`` and must support ``chkpnt['epoch']``.
        cpu: When true, the checkpoint file is loaded with
            ``map_location='cpu'``. Only relevant when ``chkpnt`` is a path.

    Returns:
        A ``(results, summary)`` tuple. ``results`` is whatever
        ``all_evaluations`` returned, with an added ``'epoch'`` entry.
        ``summary`` maps the flattened, ``'_'``-joined key path of every
        selected leaf to its value, plus ``'eval_hypernymy_avg'``, the mean
        of those values (``nan`` when no leaf was selected).

    Raises:
        FileNotFoundError: If ``chkpnt`` is a path that does not exist.
    """
    download_data()
    extra_args = {'map_location': 'cpu'} if cpu else {}
    if isinstance(chkpnt, str):
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not os.path.exists(chkpnt):
            raise FileNotFoundError(f'checkpoint not found: {chkpnt}')
        chkpnt = th.load(chkpnt, **extra_args)

    model = EntailmentConeModel(chkpnt)

    # perform the evaluations
    with th.no_grad():
        results = all_evaluations(model)

    results['epoch'] = chkpnt['epoch']

    summary = {}
    scores = []

    def collect(node, path):
        """Walk nested dicts, recording leaves whose path is selected."""
        if isinstance(node, dict):
            for key, child in node.items():
                collect(child, path + '_' + key)
        elif 'val_inv' in path and 'ap100' not in path:
            summary[path[1:]] = node  # drop the leading '_' separator
            scores.append(node)

    collect(results, '')

    # Guard against an empty selection, which previously raised
    # ZeroDivisionError; report NaN so the missing average is visible.
    if scores:
        summary['eval_hypernymy_avg'] = sum(scores) / len(scores)
    else:
        summary['eval_hypernymy_avg'] = float('nan')
    return results, summary