in data.py [0:0]
def __getitem__(self, index, forward=False):
    """Load one protein file and build ground-truth / negative rotamer samples.

    The file at ``self.files[index]`` is unpickled and validated. If it is
    rejected at any stage, the method retries with the next file (wrapping
    around the end of ``self.files``) by calling itself with ``forward=True``.

    Args:
        index: Position in ``self.files`` to load.
        forward: ``True`` when this call is a fallback retry. When
            ``FLAGS.single`` is set and ``forward`` is false, ``index`` is
            forced to 0.

    Returns:
        * ``self.split == "train"``: ``(node_embeds, node_embeds_negatives)``,
          two lists of float tensors with one entry per selected residue.
        * ``self.split in ("val", "test")``: the tuple
          ``(node_embed_short, node_embed_negative, gt_chis, neg_chis,
          res[idx])`` for the first selected residue only.
        * any other split: the two lists above, left empty.
    """
    FLAGS = self.FLAGS
    if FLAGS.single and not forward:
        index = 0

    pickle_file = self.files[index]
    # File to fall back to whenever the current one is rejected.
    next_index = (index + 1) % len(self.files)

    # NOTE: pickle can execute arbitrary code while loading; only point this
    # dataset at trusted files.
    # node_embed: D x 6
    with open(pickle_file, "rb") as handle:
        (node_embed,) = pickle.load(handle)

    # Remove proteins with small numbers of atoms
    if node_embed.shape[0] < 20:
        return self.__getitem__(next_index, forward=True)

    # Remove invalid proteins: the per-column maxima of the first three
    # columns must stay below 20, 5 and 21 respectively.
    column_max = node_embed.max(axis=0)
    if column_max[2] >= 21 or column_max[0] >= 20 or column_max[1] >= 5:
        return self.__getitem__(next_index, forward=True)

    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
    if par is None:
        return self.__getitem__(next_index, forward=True)
    if len(res) < 5:
        return self.__getitem__(next_index, forward=True)

    angles = compute_dihedral(par, child, pos, pos_exist)

    # Walk a random permutation of the interior residues (first and last are
    # excluded) and collect FLAGS.multisample of them that are neither "gly"
    # nor "ala".
    tries = 0
    perm = np.random.permutation(np.arange(1, len(res) - 1))
    select_idxs = []
    while True:
        idx = perm[tries]
        if res[idx] == "gly" or res[idx] == "ala":
            # The drawn value was never used (idx is reassigned from `perm` on
            # the next pass); the call is kept only so the `random` module's
            # state advances exactly as before.
            random.randint(1, len(res) - 2)
        else:
            select_idxs.append(idx)
            if len(select_idxs) == FLAGS.multisample:
                break
        tries += 1
        if tries > 1000 or tries == perm.shape[0]:
            # Not enough eligible residues in this protein.
            return self.__getitem__(next_index, forward=True)

    node_embeds = []
    node_embeds_negatives = []

    for idx in select_idxs:
        gt_chis = [(angles[idx, 4:8], chis_valid[idx, :4])]
        neg_samples = []
        neg_chis = []

        # Choose number of negative samples
        if FLAGS.train and self.split in ["val", "test"]:
            neg_sample = 150
        else:
            neg_sample = FLAGS.neg_sample

        if self.split == "test":
            # Count residues whose pos[:, 2] entry lies within a distance of 10
            # of this residue's, and choose a different sampling threshold
            # depending on whether the neighbourhood is dense or not.
            dist = np.sqrt(np.square(pos[idx : idx + 1, 2] - pos[:, 2]).sum(axis=1))
            neighbors = (dist < 10).sum()
            tresh = 0.95 if neighbors < 24 else 0.98
        else:
            # Outside the test split the threshold does not depend on density.
            tresh = 1.0

        if self.weighted_gauss:
            chis_list = interpolated_sample_normal(
                self.db,
                angles[idx, 1],
                angles[idx, 2],
                res[idx],
                neg_sample,
                uniform=self.uniform,
            )
        elif self.gmm:
            chis_list = mixture_sample_normal(
                self.db,
                angles[idx, 1],
                angles[idx, 2],
                res[idx],
                neg_sample,
                uniform=self.uniform,
            )
        else:
            chis_list = exhaustive_sample(
                self.db,
                angles[idx, 1],
                angles[idx, 2],
                res[idx],
                tresh=tresh,
                chi_mean=self.chi_mean,
            )
            # NOTE(review): nesting reconstructed from a de-indented source;
            # confirm the padding and shuffle belong only to this branch.
            # Repeat the candidates until there are at least neg_sample of them.
            if len(chis_list) < neg_sample:
                repeat = neg_sample // len(chis_list) + 1
                chis_list = chis_list * repeat
            random.shuffle(chis_list)

        # Index row for the selected residue: `offset` consecutive indices
        # starting at `nelem`, padded with `nelem` up to length 20. It is the
        # same for every negative sample, so build it once.
        nelem = pos_exist[:idx].sum()
        offset = pos_exist[idx].sum()
        atom_idx_row = np.concatenate(
            [np.arange(nelem, nelem + offset), np.ones(20 - offset) * (nelem)]
        )

        for i in range(neg_sample):
            chis_target = angles[:, 4:8].copy()
            chis = chis_list[i]
            # Replace only the chi entries flagged in chis_valid; the others
            # keep their current values.
            chis_target[idx] = (
                chis * chis_valid[idx, :4] + (1 - chis_valid[idx, :4]) * chis_target[idx]
            )
            pos_new = rotate_dihedral_fast(
                angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
            )
            neg_samples.append(reencode_dense_format(node_embed, pos_new, pos_exist))
            neg_chis.append((chis_target[idx], chis_valid[idx, :4]))

        node_embed_negative = np.array(neg_samples)
        atom_idxs = np.array([atom_idx_row] * neg_sample)

        # Choose the closest atoms to the chosen location:
        pos_chosen = pos[idx, 4]
        close_idx = np.argsort(np.square(node_embed[:, -3:] - pos_chosen).sum(axis=1))
        node_embed_short = node_embed[close_idx[: FLAGS.max_size]].copy()

        # NOTE(review): `pos_new` here is whatever the last negative sample
        # produced, and it is used as the reference point for every negative;
        # it is undefined if neg_sample is 0.
        pos_chosen = pos_new[idx, 4]
        close_idx_neg = np.argsort(
            np.square(node_embed_negative[:, :, -3:] - pos_chosen).sum(axis=2), axis=1
        )

        # Compute the corresponding indices for atom_idxs
        # Get the position of each index ik
        pos_code = np.argsort(close_idx_neg, axis=1)
        choose_idx = np.take_along_axis(pos_code, atom_idxs.astype(np.int32), axis=1)
        if choose_idx.max() >= FLAGS.max_size:
            # One of the residue's own atoms would fall outside the crop.
            return self.__getitem__(next_index, forward=True)

        node_embed_negative = np.take_along_axis(
            node_embed_negative, close_idx_neg[:, : FLAGS.max_size, None], axis=1
        )

        # Centre the last three columns (x, y, z) so their mean is zero.
        node_embed_short[:, -3:] = node_embed_short[:, -3:] - np.mean(
            node_embed_short[:, -3:], axis=0
        )
        node_embed_negative[:, :, -3:] = node_embed_negative[:, :, -3:] - np.mean(
            node_embed_negative[:, :, -3:], axis=1, keepdims=True
        )

        if FLAGS.augment:
            # Augment the data with random rotations: one draw for the
            # ground-truth crop and one draw per negative crop.
            rot_matrix = self.so3.rvs(1)
            node_embed_short[:, -3:] = np.matmul(node_embed_short[:, -3:], rot_matrix)
            rot_matrix_neg = self.so3.rvs(node_embed_negative.shape[0])
            node_embed_negative[:, :, -3:] = np.matmul(
                node_embed_negative[:, :, -3:], rot_matrix_neg
            )

        # Additionally scale the coordinates down by a factor of 10.
        node_embed_short[:, -3:] = node_embed_short[:, -3:] / 10.0
        node_embed_negative[:, :, -3:] = node_embed_negative[:, :, -3:] / 10.0

        # Convert to float tensors.
        node_embed_short = torch.from_numpy(node_embed_short).float()
        node_embed_negative = torch.from_numpy(node_embed_negative).float()

        if self.split == "train":
            node_embeds.append(node_embed_short)
            node_embeds_negatives.append(node_embed_negative)
        elif self.split in ["val", "test"]:
            # Evaluation returns after the first selected residue.
            return node_embed_short, node_embed_negative, gt_chis, neg_chis, res[idx]

    return node_embeds, node_embeds_negatives