def __call__()

in python/dgllife/utils/featurizers.py [0:0]


    def __call__(self, mol):
        """Featurizes every ordered pair of atoms in the input molecule.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance.

        Returns
        -------
        dict
            Mapping self.bond_data_field to a float32 tensor of shape (N, M), where
            N = V * V is the number of ordered atom pairs (V being the number of atoms,
            with the pair (i, j) stored in row i * V + j) and M is the feature size
            depending on max_length. Rows of self-pairs (i, i) are all zeros.
        """
        n_atoms = mol.GetNumAtoms()
        # Shortest path for every ordered pair of distinct atoms. Both (i, j) and
        # (j, i) are queried explicitly rather than reversing one of them, so each
        # direction uses exactly the path RDKit reports for it.
        paths_dict = {
            (i, j): Chem.rdmolops.GetShortestPath(mol, i, j)
            for i in range(n_atoms)
            for j in range(n_atoms)
            if i != j
        }

        # For every pair of atoms sharing a ring, collect the distinct
        # (ring size, is aromatic) descriptors of the rings they share.
        rings_dict = {}
        for ring in Chem.GetSymmSSSR(mol):
            ring = list(ring)
            # A ring counts as aromatic only if all of its atoms are aromatic.
            ring_type = (
                len(ring),
                all(mol.GetAtomWithIdx(atom_idx).GetIsAromatic() for atom_idx in ring),
            )
            for ring_idx, atom_idx in enumerate(ring):
                # The slice starts at ring_idx, so each unordered pair is visited
                # once. It also yields the (atom, atom) self-pair, which is stored
                # but never looked up below because self-pairs get zero features.
                for other_idx in ring[ring_idx:]:
                    atom_pair = self.ordered_pair(atom_idx, other_idx)
                    ring_types = rings_dict.setdefault(atom_pair, [])
                    if ring_type not in ring_types:
                        ring_types.append(ring_type)

        # Featurize in row-major (i, j) order.
        feats = []
        for i in range(n_atoms):
            for j in range(n_atoms):
                if i == j:
                    # paths_dict holds no entry for self-pairs; they get an
                    # all-zero feature vector of the hard-coded feature size.
                    feats.append(np.zeros(7 * self.max_length + 7))
                    continue
                ring_info = rings_dict.get(self.ordered_pair(i, j), [])
                feats.append(self.bond_features(mol, paths_dict[(i, j)], ring_info))

        # Pack the per-pair vectors into one ndarray first: building a tensor
        # directly from a Python list of ndarrays is very slow in PyTorch.
        return {self.bond_data_field: torch.tensor(np.array(feats)).float()}