in src/dataset.py [0:0]
def __iter__(self) -> Iterator[tuple]:
    """Yield ``(graph, label)`` pairs for this process's share of ``self.docs``.

    With single-process loading every document in ``self.docs`` is used.
    Inside a DataLoader worker, ``self.docs`` is split into contiguous
    chunks of ``ceil(len(self.docs) / num_workers)`` documents and this
    worker only handles the chunk selected by its worker id.

    The selected proteins are fetched by ``_id`` from the MongoDB collection
    ``self.db_name``/``self.collection_name`` at ``self.db_uri``, converted
    with ``convert_to_graph(protein, k=self.k)`` and paired with
    ``self.labels[protein["_id"]]``.

    This is a generator function on purpose: the ``MongoClient`` must stay
    open while the cursor is being consumed. Returning a lazy generator
    expression from inside the ``with`` block would close the client before
    the first document is read.

    Yields:
        Tuples ``(graph, label)``, in the order the cursor returns the
        documents (a ``$in`` query does not promise the order of
        ``self.docs``).

    Raises:
        KeyError: If a fetched protein's ``_id`` is missing from
            ``self.labels``.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Single-process data loading: this process handles every document.
        shard = self.docs
    else:
        # Worker process: take one contiguous slice of the workload.
        total = len(self.docs)
        per_worker = math.ceil(total / worker_info.num_workers)
        iter_start = worker_info.id * per_worker
        # Clamp so the last worker(s) do not run past the end.
        iter_end = min(iter_start + per_worker, total)
        shard = self.docs[iter_start:iter_end]
    protein_ids = [doc["_id"] for doc in shard]

    # Retrieve the selected proteins by _id from the document DB. The client
    # is closed when the generator is exhausted, closed, or garbage-collected.
    with MongoClient(self.db_uri) as client:
        collection = client[self.db_name][self.collection_name]
        cur = collection.find(
            {"_id": {"$in": protein_ids}},
            # "_id" is included by MongoDB unless explicitly excluded, so
            # the label lookup below still works with this projection.
            projection={"coords": True, "seq": True},
        )
        for protein in cur:
            yield (
                convert_to_graph(protein, k=self.k),
                self.labels[protein["_id"]],
            )