in vis_sandbox.py [0:0]
def new_model(model, FLAGS, node_embed):
    """Scan every side-chain chi angle of one PDB entry and score it with ``model``.

    For each residue and each valid chi angle (up to 4), the chi angle is offset
    by every delta in ``range(-180, 180, 3)`` (presumably degrees, matching the
    360 output columns). For each offset:

    * the structure is rebuilt;
    * it is cropped to the ``FLAGS.max_size`` atoms closest to atom 4 of the
      residue, and the crop is centered;
    * the crop's energy is averaged over random 3D rotations.

    The results are pickled to ``<FLAGS.outdir>/<pdb_name>_rot_energies.p``.

    Args:
        model: energy model, called as ``model.forward(x)`` on a batch of crops.
        FLAGS: options object; reads ``pdb_name``, ``max_size`` and ``outdir``.
        node_embed: ignored. The embedding is always re-loaded from the dataset
            pickle for ``FLAGS.pdb_name``. The parameter is kept for interface
            compatibility.

    Output pickle keys:
        all_energies: ``(num_res, 4, 360)`` array. Entry ``[i, c, d % 360]`` is the
            rotation-averaged energy with chi ``c`` of residue ``i`` offset by
            ``d``. Entries that were never computed are NaN: these are invalid
            chis and angle columns not sampled by the step of 3.
        chis_valid, chis_target_initial, res: passed through from parsing.
        avg_chi_angle_energy: ``(4, 360)`` mean of ``all_energies`` over the
            residues for which each chi is valid.
        surface_core_type: per-residue "core" / "surface" / "unlabeled" label.
    """
    batch_size = 120
    n_rotations = 100

    pdb_name = FLAGS.pdb_name  # e.g. '6mdw.0'
    pickle_file = f"/private/home/yilundu/dataset/mmcif/mmCIF/{pdb_name[1:3]}/{pdb_name}.p"
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted dataset files.
    with open(pickle_file, "rb") as f:
        (node_embed,) = pickle.load(f)

    par, child, pos, pos_exist, res, chis_valid = parse_dense_format(node_embed)
    angles = compute_dihedral(par, child, pos, pos_exist)
    # Columns :4 of `angles` hold backbone dihedrals; columns 4:8 hold side-chain chis.
    chis_target_initial = angles[:, 4:8].copy()

    num_res = len(res)
    # 4 possible chi angles x 360 columns (one per delta mod 360). Start with NaN
    # rather than np.empty so that entries never computed are explicit, not garbage.
    all_energies = np.full((num_res, 4, 360), np.nan)
    surface_core_type = _surface_core_labels(pos)

    for idx in tqdm(range(num_res)):
        for chi_num in range(4):
            if not chis_valid[idx, chi_num]:
                continue
            for angle_deltas in batch(range(-180, 180, 3), batch_size):
                crops = []
                for angle_delta in angle_deltas:
                    # Work on a fresh copy so only this one chi angle is perturbed.
                    chis_target = chis_target_initial.copy()
                    chis_target[idx, chi_num] += angle_delta
                    # pos_new is n residues x 20 atoms x 3 (xyz)
                    pos_new = rotate_dihedral_fast(
                        angles, par, child, pos, pos_exist, chis_target, chis_valid, idx
                    )
                    node_neg_embed = reencode_dense_format(node_embed, pos_new, pos_exist)
                    # Crop to the FLAGS.max_size atoms nearest atom 4 of this residue
                    # (the first side-chain atom), then center the crop's xyz at 0.
                    pos_chosen = pos_new[idx, 4]
                    close_idx = np.argsort(
                        np.square(node_neg_embed[:, -3:] - pos_chosen).sum(axis=1)
                    )
                    crop = node_neg_embed[close_idx[: FLAGS.max_size]].copy()
                    crop[:, -3:] = crop[:, -3:] - np.mean(crop[:, -3:], axis=0)
                    crops.append(crop)

                energies = _rotation_averaged_energies(model, np.stack(crops), n_rotations)
                # Negative deltas map to column delta + 360 (same as numpy's negative
                # indexing), written explicitly here.
                all_energies[idx, chi_num, np.mod(angle_deltas, 360)] = energies

    # Average over residues where each chi is valid. Mask explicitly instead of
    # multiplying, so NaNs in invalid rows do not poison the sum.
    valid = chis_valid[:num_res, :4]
    masked = np.where(valid[:, :, None] != 0, all_energies, 0.0)
    avg_chi_angle_energy = masked.sum(0) / np.expand_dims(valid.sum(0), 1)

    output = {
        "all_energies": all_energies,
        "chis_valid": chis_valid,
        "chis_target_initial": chis_target_initial,
        "avg_chi_angle_energy": avg_chi_angle_energy,  # make four plots from this (4, 360),
        "res": res,
        "surface_core_type": surface_core_type,
    }

    os.makedirs(FLAGS.outdir, exist_ok=True)
    output_path = osp.join(FLAGS.outdir, f"{pdb_name}_rot_energies.p")
    with open(output_path, "wb") as f:
        pickle.dump(output, f)


def _surface_core_labels(pos):
    """Label each residue "core", "surface" or "unlabeled" by its neighbor count.

    A residue's neighbors are the residues whose atom-2 position lies within 10
    (in the units of ``pos``, presumably Angstrom) of its own atom-2 position.
    The count includes the residue itself.

    Labels:
        * >= 24 neighbors: "core"
        * <= 16 neighbors: "surface"
        * otherwise: "unlabeled"
    """
    anchors = pos[:, 2]
    labels = []
    for anchor in anchors:
        dist = np.sqrt(np.square(anchor - anchors).sum(axis=1))
        neighbors = (dist < 10).sum()
        if neighbors >= 24:
            labels.append("core")
        elif neighbors <= 16:
            labels.append("surface")
        else:
            labels.append("unlabeled")
    return labels


def _rotation_averaged_energies(model, crops, n_rotations):
    """Return ``model`` energies of each crop, averaged over random 3D rotations.

    ``crops`` is a numpy array of shape (batch, n_atoms, n_features) whose last
    three feature columns are xyz coordinates. Each crop is replicated
    ``n_rotations`` times, and each copy's coordinates are multiplied by a random
    SO(3) matrix. The energies are then averaged per crop. Returns a numpy array
    of shape (batch,).
    """
    n_crops = crops.shape[0]
    embeds = torch.from_numpy(crops).float().cuda()
    # (batch, n_rotations, n_atoms, n_features)
    embeds = embeds.unsqueeze(1).repeat(1, n_rotations, 1, 1)
    rot_matrix = torch.from_numpy(special_ortho_group(3).rvs(n_rotations)).float().cuda()
    embeds[:, :, :, -3:] = torch.matmul(embeds[:, :, :, -3:], rot_matrix)
    # Fold rotations into the batch dimension for one forward pass.
    embeds = embeds.reshape(n_crops * n_rotations, *embeds.shape[2:])
    # Inference only: avoids holding the autograd graph, and avoids .numpy()
    # failing on tensors that require grad.
    with torch.no_grad():
        energies = model.forward(embeds)
        # Reshape with the real batch length (a final partial batch may be smaller).
        energies = energies.reshape(n_crops, -1).mean(1)
    return energies.cpu().numpy()