def show_msa_info()

in src/analysis/notebook_utils.py [0:0]


def show_msa_info(
    single_chain_msas: Sequence[parsers.Msa],
    sequence_index: int) -> None:
  """Prints info and shows a plot of the deduplicated single chain MSA.

  The `.sequences` of every MSA in `single_chain_msas` are pooled and
  deduplicated, keeping the first occurrence of each sequence. The number of
  unique sequences is printed. If there is at least one sequence, a line plot
  of the number of non-gap ('-') residues in each alignment column is shown.

  Args:
    single_chain_msas: MSAs whose `.sequences` are pooled together. The pooled
      sequences are expected to be aligned, i.e. all of the same length.
    sequence_index: Index of the input sequence; used only in the printed
      message and the plot title.

  Raises:
    KeyError: If a sequence contains a character other than an uppercase
      letter 'A'-'Z' or the gap character '-'.
    ValueError: If the pooled, deduplicated sequences do not all have the
      same length.
  """
  full_single_chain_msa = []
  for single_chain_msa in single_chain_msas:
    full_single_chain_msa.extend(single_chain_msa.sequences)

  # Deduplicate but preserve order (hence can't use set).
  deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa))
  total_msa_size = len(deduped_full_single_chain_msa)
  print(f'\n{total_msa_size} unique sequences found in total for sequence '
        f'{sequence_index}\n')

  if not deduped_full_single_chain_msa:
    # Nothing to plot: an empty array is 1-D, so the per-column sum below would
    # collapse to a single scalar and yield a meaningless one-point plot.
    return

  # A per-column count only makes sense for an aligned MSA. Reject ragged input
  # explicitly instead of relying on numpy's version-dependent behaviour.
  seq_lengths = {len(seq) for seq in deduped_full_single_chain_msa}
  if len(seq_lengths) != 1:
    raise ValueError(
        f'MSA sequences for sequence {sequence_index} must all have the same '
        f'length; got lengths ranging from {min(seq_lengths)} to '
        f'{max(seq_lengths)}.')

  # Integer-encode residues; the mapping doubles as validation of the alphabet.
  aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}
  msa_arr = np.array(
      [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa])
  # Shape (num_columns,): how many sequences have a residue in each column.
  non_gap_counts = np.sum(msa_arr != aa_map['-'], axis=0)

  plt.figure(figsize=(12, 3))
  plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence '
            f'{sequence_index}')
  plt.plot(non_gap_counts, color='black')
  plt.ylabel('Non-Gap Count')
  # About three evenly spaced ticks from 0 to the number of unique sequences.
  plt.yticks(range(0, total_msa_size + 1, max(1, total_msa_size // 3)))
  plt.show()