so_vector/_tools/parse_embed.py (60 lines of code) (raw):

"""Convert a Stack Overflow posts XML dump into newline-delimited JSON.

Only question rows (``PostTypeId == 1``) are kept. Each question's title is
embedded with a sentence-transformers model and the resulting vector is stored
alongside the other extracted fields.
"""

import argparse
import json
import re
import xml.sax
from pathlib import Path

from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer

model = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
# Loaded once at import time; every title is encoded with this single instance.
embedding_model = SentenceTransformer(model)

# Not used anywhere in this file.
# NOTE(review): presumably a leftover -- kept in case another module reads it.
thedatawriter = None

# Fallback output stream for PostsHandler instances built without an explicit
# ``out``. The handler originally wrote to a global of this name, so code that
# assigns it before parsing keeps working.
myjsonfile = None

# Hoisted so the patterns are not looked up again for every row.
_TAG_DELIMITERS = re.compile(r"[<>]+")
_WHITESPACE_RUN = re.compile(r"\s+")


class PostsHandler(xml.sax.ContentHandler):
    """SAX handler that writes one JSON line per question ``<row>`` element.

    Args:
        out: Writable text stream that receives the JSON lines. When omitted,
            the module-level ``myjsonfile`` is used instead.
    """

    def __init__(self, out=None):
        super().__init__()
        self._out = out

    def startElement(self, name, attrs):
        """Serialize a question row to the output stream; ignore anything else.

        Args:
            name: Element name reported by the SAX parser.
            attrs: Attributes of the element.

        Raises:
            KeyError: If a ``row`` lacks ``PostTypeId``, or a question row
                lacks ``Id``.
            RuntimeError: If no output stream was given and ``myjsonfile`` is
                unset.
        """
        if name != "row":
            return
        # only keep questions, dropping the answer posts
        if int(attrs["PostTypeId"]) != 1:
            return

        record = {}
        # In some questions e.g. 10030718 the ownerID is missing and we have
        # OwnerDisplayName instead
        if "OwnerUserId" in attrs:
            record["user"] = attrs["OwnerUserId"]
        elif "OwnerDisplayName" in attrs:
            record["user"] = attrs["OwnerDisplayName"]

        # Tags arrive as "<a><b>"; splitting on bracket runs leaves empty
        # strings at both ends, which are dropped.
        raw_tags = _TAG_DELIMITERS.split(attrs["Tags"]) if "Tags" in attrs else []
        record["tags"] = [tag for tag in raw_tags if tag]
        record["type"] = "question"
        record["questionId"] = attrs["Id"]
        if "CreationDate" in attrs:
            record["creationDate"] = attrs["CreationDate"]
        if "Title" in attrs:
            record["title"] = attrs["Title"].replace("\n", " ").replace("\r", " ")
            record["titleVector"] = embedding_model.encode(
                record["title"], normalize_embeddings=True
            ).tolist()
        if "AcceptedAnswerId" in attrs:
            record["acceptedAnswerId"] = attrs["AcceptedAnswerId"]
        if "Body" in attrs:
            soup = BeautifulSoup(attrs["Body"], "html.parser")
            # "\r" is removed outright (not turned into a space) before the
            # whitespace runs are collapsed, matching the established output.
            text = soup.get_text().replace("\n", " ").replace("\r", "")
            record["body"] = _WHITESPACE_RUN.sub(" ", text)

        sink = self._out if self._out is not None else myjsonfile
        if sink is None:
            raise RuntimeError(
                "PostsHandler has no output stream: pass one to the constructor"
            )
        sink.write(json.dumps(record, separators=(",", ":")))
        sink.write("\n")


def main():
    """Parse the XML file named on the command line into a sibling ``.json`` file."""
    arg_parser = argparse.ArgumentParser(
        description=(
            "Script to process stack overflow posts. \n"
            "Filters out non-question type posts and computes vector embedding for Title fields"
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument("file", help="Path to XML posts file")
    args = arg_parser.parse_args()

    posts_path = Path(args.file)
    output = posts_path.with_suffix(".json")
    if output == posts_path:
        # Opening the output for writing would truncate the input itself.
        arg_parser.error("input file must not already have a .json suffix")

    parser = xml.sax.make_parser()
    print("Preprocessing stack overflow posts")
    # Binary mode hands the raw bytes to the XML parser instead of decoding
    # them with the locale's default encoding; both handles are closed on exit.
    with open(posts_path, "rb") as posts_file, open(
        output, "w", encoding="utf-8"
    ) as json_file:
        parser.setContentHandler(PostsHandler(json_file))
        parser.parse(posts_file)


if __name__ == "__main__":
    main()