def __iter__()

in src/dataset.py [0:0]


    def __iter__(self) -> Iterator[tuple]:
        """Yield ``(graph, label)`` pairs for this process's share of ``self.docs``.

        Inside a ``torch.utils.data.DataLoader`` worker, ``self.docs`` is split
        into contiguous chunks of ``ceil(len(self.docs) / num_workers)`` entries
        and each worker handles only its own chunk. In single-process loading,
        every entry of ``self.docs`` is used.

        The ``_id`` of each selected doc is looked up in
        ``client[self.db_name][self.collection_name]`` at ``self.db_uri``,
        fetching only the ``coords`` and ``seq`` fields. Each returned protein is
        converted with ``convert_to_graph(protein, k=self.k)`` and paired with
        ``self.labels[protein["_id"]]``. Items come out in the order the database
        cursor returns them, which is not necessarily the order of ``self.docs``.

        This is written as a generator (``yield`` inside the ``with``) so the
        ``MongoClient`` stays open while the cursor is consumed. It is closed
        when iteration finishes or the generator is closed. Returning a lazy
        generator expression from inside the ``with`` would close the client
        before the first item is read.
        """
        docs = self.docs
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # In a DataLoader worker: take this worker's contiguous chunk.
            # Slicing clamps at len(docs), so trailing workers may get an
            # empty chunk.
            per_worker = int(math.ceil(len(docs) / float(worker_info.num_workers)))
            iter_start = worker_info.id * per_worker
            docs = docs[iter_start:iter_start + per_worker]

        protein_ids = [doc["_id"] for doc in docs]
        if not protein_ids:
            # Nothing assigned to this process; skip opening a DB connection.
            return

        # Retrieve the proteins by _id and stream them while the client is open.
        with MongoClient(self.db_uri) as client:
            collection = client[self.db_name][self.collection_name]
            cursor = collection.find(
                {"_id": {"$in": protein_ids}},
                projection={"coords": True, "seq": True},
            )
            for protein in cursor:
                yield (
                    convert_to_graph(protein, k=self.k),
                    self.labels[protein["_id"]],
                )