def search_topic()

in fulltext_search/search_sharded.py [0:0]


# Characters with special meaning in the full-text query syntax. Each one is
# replaced by a space so topic text is searched as plain keywords rather than
# parsed as operators (e.g. a standalone "-" in "Health - Fitness" would become
# "-^2" after boosting).
# NOTE(review): this mirrors the Manticore Search escaping list
# (! " $ ' ( ) - / < @ \ ^ | ~); the bare-string `query_string` request and the
# `word^N` IDF-boost syntax suggest that backend — confirm.
_QUERY_SPECIAL_CHARS = str.maketrans(dict.fromkeys("!\"$'()-/<@\\^|~", " "))


def search_topic(sample):
    """Retrieve up to ``args.n_pages`` search hits for a single topic.

    ``sample`` is indexed as a batch dict whose ``"top_category"``,
    ``"subcategory"`` and ``"subtopic"`` values are lists; only element 0 of
    each is used (presumably a batched map callback with one row per batch —
    confirm against the caller).

    A ``query_string`` search is run first, with every subtopic token
    IDF-boosted (``word^2``). If it returns fewer than ``args.n_pages`` hits,
    a plain ``match`` query on the ``content`` field is run to top the results
    up. Hits are deduplicated by ``_id`` (boosted hits take precedence) and
    truncated to ``args.n_pages``.

    Returns:
        A dict of equal-length lists with one entry per hit: the three topic
        columns repeated, ``"topic_hits"`` (the hit objects themselves) and
        ``"num_hits"`` (the total hit count, repeated). All lists are empty
        when nothing is found.
    """
    # Strip first, then blank out query-syntax characters (same order as before).
    top_category, subcategory, subtopic = (
        sample[column][0].strip().translate(_QUERY_SPECIAL_CHARS)
        for column in ("top_category", "subcategory", "subtopic")
    )
    # boosting the IDF score of subtopic tokens
    boosted_subtopic = " ".join(w + "^2" for w in subtopic.split())
    match_query = " ".join([top_category, subcategory, subtopic])
    boosted_query = " ".join([top_category, subcategory, boosted_subtopic])

    boosted_hits = run_query({"query_string": boosted_query}, args.n_pages)
    print(f"Boosted hits: {len(boosted_hits)} for {boosted_query}", file=sys.stderr)
    if len(boosted_hits) < args.n_pages:
        # Over-fetch by the number of boosted hits so that, after duplicates of
        # those are dropped below, enough match hits remain to fill n_pages.
        match_hits = run_query(
            {"match": {"content": match_query}}, args.n_pages + len(boosted_hits)
        )
        print(f"Match hits: {len(match_hits)} for {match_query}", file=sys.stderr)
    else:
        match_hits = []

    # Deduplicate by document id, keeping boosted hits ahead of match hits.
    seen_ids = set()
    hits = []
    for hit in boosted_hits + match_hits:
        if hit["_id"] not in seen_ids:
            seen_ids.add(hit["_id"])
            hits.append(hit)
    hits = hits[: args.n_pages]

    # Expand to one output row per hit by repeating the single-element columns.
    n_hits = len(hits)
    results = {
        "top_category": sample["top_category"] * n_hits,
        "subcategory": sample["subcategory"] * n_hits,
        "subtopic": sample["subtopic"] * n_hits,
        "topic_hits": hits,
        "num_hits": [n_hits] * n_hits,
    }
    return results