codes/example_attention_compute.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "667768fd",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import math\n",
"from einops.layers.torch import Rearrange\n",
"from einops import rearrange\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "95e8ea0f",
"metadata": {},
"outputs": [],
"source": [
"## Data processing ##\n",
"\n",
"## We provided an example data (data/info_attention.txt) for reference. ##\n",
"\n",
"## Please provide the following information in that file: \n",
"## (1) n x 1029bp DNA sequences, n is the number of input peaks. For each peak, the central nucleotide should be sampled at the central of that peak, and the DNA sequence length should be of 1029 bp;\n",
"## (2) n x 1029bp ATAC signal values. The ATAC signal values should be from the .bigWig file, each value is at the same position of your sampled nucleotide.\n",
"## (3) Please give n x IDs for reference of each peak\n",
"\n",
"## Please write your information in a .txt file, for example:\n",
"## line 1 (ID of this peak): >chr12 135145-135516 \n",
"## line 2 (1029bp DNA): ATCGATCG ... ... TCGA\n",
"## line 3 (1029bp ATAC): 1.28971 1.11121 ... ... 0.01234\n",
"## line 4 (next peak ID): ... ...\n",
"## ... ...\n",
"## input information of n peaks should have 3*n lines"
]
},
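{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4",
"metadata": {},
"outputs": [],
"source": [
"## A minimal validation sketch (not part of the original pipeline) for the\n",
"## 3-lines-per-peak layout described above. The helper name\n",
"## validate_peak_file is hypothetical.\n",
"def validate_peak_file(path):\n",
"    with open(path) as f:\n",
"        lines = [l.strip() for l in f if l.strip()]\n",
"    assert len(lines) % 3 == 0, 'expected 3 lines per peak (ID, DNA, ATAC)'\n",
"    for i in range(0, len(lines), 3):\n",
"        assert lines[i].startswith('>'), 'peak ID lines should start with >'\n",
"        assert len(lines[i+1]) == 1029, 'DNA sequence should be 1029 bp'\n",
"        assert len(lines[i+2].split()) == 1029, 'expected 1029 ATAC values'\n",
"    return len(lines) // 3  # number of peaks\n",
"\n",
"# e.g. n_peaks = validate_peak_file('../data/info_attention.txt')"
]
},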
{
"cell_type": "code",
"execution_count": 3,
"id": "e51b17df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([44, 1024]) torch.Size([44, 1024])\n"
]
}
],
"source": [
"from data_processing import *\n",
"genome_dict = torch.load('../data/kmer_dict.pkl')\n",
"\n",
"raw_input_data = narrowPeak_Reader('../data/info_attention.txt')\n",
"\n",
"data_len = int(len(raw_input_data)/3)\n",
"peak_id = list(range(data_len))\n",
"dna_in = list(range(data_len))\n",
"atac_in = list(range(data_len))\n",
"\n",
"for i_data in range (data_len): \n",
" peak_id[i_data] = raw_input_data[i_data*3][0]\n",
" dna_in[i_data] = pre_processing(tokenizer(raw_input_data[i_data*3+1][0], 6), genome_dict)\n",
" for i in range (len(raw_input_data[i_data*3+2])):\n",
" raw_input_data[i_data*3+2][i]=float(raw_input_data[i_data*3+2][i])\n",
" atac_in[i_data] = pre_pro(raw_input_data[i_data*3+2], 6)\n",
"dna_in = torch.tensor(dna_in)\n",
"atac_in = torch.tensor(atac_in)\n",
"\n",
"print(dna_in.shape, atac_in.shape)"
]
},
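{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5",
"metadata": {},
"outputs": [],
"source": [
"## Shape check (a sketch): each 1029 bp sequence tokenizes into 1024\n",
"## overlapping 6-mers (1029 - 6 + 1 = 1024), which the model later treats\n",
"## as 8 chunks of 128 tokens each.\n",
"assert dna_in.shape == (data_len, 1029 - 6 + 1)\n",
"assert atac_in.shape == dna_in.shape\n",
"assert dna_in.shape[1] == 8 * 128"
]
},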
{
"cell_type": "code",
"execution_count": 4,
"id": "661ec912",
"metadata": {},
"outputs": [],
"source": [
"## Specify your TSS location and strand ##\n",
"## Our example gene is GATA1, the closest ATAC peak to TSS is the 20th peak, and the gene is + strand ##\n",
"tss_loc = 20\n",
"direction = '+'"
]
},
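{
"cell_type": "code",
"execution_count": null,
"id": "c3d4e5f6",
"metadata": {},
"outputs": [],
"source": [
"## A hedged sketch (not part of the original pipeline) of how tss_loc could\n",
"## be derived from the parsed peak IDs, assuming IDs of the form\n",
"## '>chrX_start_end' as printed below; closest_peak and tss_position are\n",
"## hypothetical names.\n",
"def closest_peak(peak_ids, tss_position):\n",
"    mids = []\n",
"    for pid in peak_ids:\n",
"        _, start, end = pid.lstrip('>').rsplit('_', 2)\n",
"        mids.append((int(start) + int(end)) // 2)  # peak midpoint\n",
"    return min(range(len(mids)), key=lambda i: abs(mids[i] - tss_position))\n",
"\n",
"# e.g. tss_loc = closest_peak(peak_id, tss_position)  # TSS coordinate from your annotation"
]
},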
{
"cell_type": "code",
"execution_count": 5,
"id": "99149c3c",
"metadata": {},
"outputs": [],
"source": [
"## Load REformer model ##"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "eb4e9787",
"metadata": {},
"outputs": [],
"source": [
"from models import *\n",
"from attention import *\n",
"\n",
"cuda = torch.device('cuda', 0)\n",
"\n",
"dna_embed = torch.load(\"../pretrained_models/dna_embed.pkl\" , map_location='cpu').to(cuda)\n",
"atac_embed = torch.load(\"../pretrained_models/atac_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pos1_embed = torch.load(\"../pretrained_models/pos1_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pos2_embed = torch.load(\"../pretrained_models/pos2_embed.pkl\" , map_location='cpu').to(cuda)\n",
"tss_embed = torch.load(\"../pretrained_models/tss_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pad_embed = torch.load(\"../pretrained_models/pad_embed.pkl\" , map_location='cpu').to(cuda)\n",
"encoder_1 = torch.load(\"../pretrained_models/transformer_1.pkl\", map_location='cpu').to(cuda)\n",
"encoder_2 = torch.load(\"../pretrained_models/transformer_2.pkl\", map_location='cpu').to(cuda)\n",
"encoder_3 = torch.load(\"../pretrained_models/transformer_3.pkl\", map_location='cpu').to(cuda)\n",
"atten_pool = torch.load(\"../pretrained_models/atten_pool.pkl\" , map_location='cpu').to(cuda)\n",
"ff_net = torch.load(\"../pretrained_models/feedforward.pkl\" , map_location='cpu').to(cuda)"
]
},
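{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"## Note: the .pkl checkpoints above store entire nn.Module objects, so\n",
"## torch.load unpickles them and needs the classes from models.py and\n",
"## attention.py to be importable. The cells above also assume a GPU; a\n",
"## hedged fallback to CPU would be:\n",
"cuda = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')"
]
},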
{
"cell_type": "code",
"execution_count": 7,
"id": "e4cab5fc",
"metadata": {},
"outputs": [],
"source": [
"## compute attention score ##"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "005228d2",
"metadata": {},
"outputs": [],
"source": [
"dna_in = dna_in.to(cuda)\n",
"sig_in = atac_in.to(cuda)\n",
"\n",
"pos1 = torch.ones(129, dtype=int).to(cuda)\n",
"for i in range (len(pos1)):\n",
" pos1[i]+=i\n",
"pos2 = torch.ones(8, dtype=int).to(cuda)\n",
"for i in range (len(pos2)):\n",
" pos2[i]+=i\n",
"pos3 = torch.zeros(150, dtype=int).to(cuda)\n",
"pos3[tss_loc] = 0\n",
"if direction=='+':\n",
" pos3[tss_loc-1] = 1\n",
" pos3[tss_loc+1] = 2\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
"if direction=='-':\n",
" pos3[tss_loc-1] = 2\n",
" pos3[tss_loc+1] = 1\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
" \n",
"with torch.no_grad():\n",
" \n",
" CLS = dna_embed(torch.ones(dna_in.shape[0]*8, 1, dtype=int).to(cuda))\n",
" x_POS_1 = pos1_embed(pos1) \n",
" x_mul = dna_embed(dna_in.int().reshape(dna_in.shape[0]*8, 128)) + atac_embed(sig_in.int().reshape(dna_in.shape[0]*8, 128)) \n",
" x_embed = torch.cat((CLS, x_mul), dim=1)\n",
" x_enc_1 = encoder_1(x_embed + x_POS_1)[:,0,:].reshape(dna_in.shape[0],8,2048)\n",
" x_POS_2 = pos2_embed(pos2)\n",
" x_enc_2 = encoder_2(x_enc_1+x_POS_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b n d -> b d n')\n",
" x_enc_2 = atten_pool(x_enc_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b d n -> b n d').squeeze(1)\n",
" x_pad = pad_embed(torch.zeros(150-x_enc_2.shape[0], dtype=int).to(cuda))\n",
" x_eb3 = torch.cat((x_enc_2, x_pad), dim=0)\n",
" x_POS_3 = tss_embed(pos3)\n",
" x_enc_3 = x_eb3 + x_POS_3\n",
"\n",
" attn_probs = extract_selfattention_maps(encoder_3.encoder,x_enc_3.unsqueeze(0))\n",
" SM = nn.Softmax(dim=2)\n",
" attention_score = SM(attn_probs[0]).mean(0).sum(0)[0:dna_in.shape[0]].tolist()"
]
},
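{
"cell_type": "code",
"execution_count": null,
"id": "e5f6a7b8",
"metadata": {},
"outputs": [],
"source": [
"## A minimal sketch of the score reduction above on dummy data, assuming\n",
"## attn_probs[0] has shape (heads, 150, 150) = (heads, query, key):\n",
"## softmax over the key dimension, mean over heads, sum over queries,\n",
"## then keep the entries for the real (non-padded) peaks.\n",
"dummy_map = torch.randn(8, 150, 150)  # 8 heads is a hypothetical choice\n",
"dummy_score = nn.Softmax(dim=2)(dummy_map).mean(0).sum(0)[:dna_in.shape[0]]\n",
"print(dummy_score.shape)  # torch.Size([44]) for the 44 example peaks"
]
},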
{
"cell_type": "code",
"execution_count": 9,
"id": "5acc762e",
"metadata": {},
"outputs": [],
"source": [
"## print results ##"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b7151318",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
">chrX_48648355_48648753\n",
"Attention score: 1.028488039970398\n",
">chrX_48652218_48652494\n",
"Attention score: 1.3975310325622559\n",
">chrX_48660471_48661224\n",
"Attention score: 1.28517746925354\n",
">chrX_48675912_48677365\n",
"Attention score: 1.1235225200653076\n",
">chrX_48680543_48680748\n",
"Attention score: 1.3664084672927856\n",
">chrX_48683426_48683886\n",
"Attention score: 1.525336503982544\n",
">chrX_48689104_48689575\n",
"Attention score: 0.9983937740325928\n",
">chrX_48695968_48697301\n",
"Attention score: 0.9551467299461365\n",
">chrX_48701826_48702668\n",
"Attention score: 1.094896674156189\n",
">chrX_48737081_48738023\n",
"Attention score: 1.2857067584991455\n",
">chrX_48750823_48751127\n",
"Attention score: 1.350295066833496\n",
">chrX_48753382_48754579\n",
"Attention score: 1.7550113201141357\n",
">chrX_48761672_48762478\n",
"Attention score: 1.7182071208953857\n",
">chrX_48765487_48765941\n",
"Attention score: 1.5139516592025757\n",
">chrX_48770352_48771155\n",
"Attention score: 1.1152039766311646\n",
">chrX_48776735_48777456\n",
"Attention score: 1.668673038482666\n",
">chrX_48779520_48779697\n",
"Attention score: 1.3726859092712402\n",
">chrX_48782680_48783266\n",
"Attention score: 1.2104418277740479\n",
">chrX_48785143_48785558\n",
"Attention score: 2.014192581176758\n",
">chrX_48785717_48786295\n",
"Attention score: 1.4691805839538574\n",
">chrX_48786374_48786792\n",
"Attention score: 2.1485726833343506\n",
">chrX_48788441_48788864\n",
"Attention score: 1.2086058855056763\n",
">chrX_48794064_48795657\n",
"Attention score: 2.115185022354126\n",
">chrX_48800485_48801011\n",
"Attention score: 1.7803266048431396\n",
">chrX_48801263_48801441\n",
"Attention score: 1.9881824254989624\n",
">chrX_48801571_48802610\n",
"Attention score: 1.839195728302002\n",
">chrX_48816491_48816877\n",
"Attention score: 0.9433814287185669\n",
">chrX_48823026_48823601\n",
"Attention score: 1.2529704570770264\n",
">chrX_48833554_48834232\n",
"Attention score: 1.0208148956298828\n",
">chrX_48841756_48842086\n",
"Attention score: 1.1173759698867798\n",
">chrX_48854965_48855186\n",
"Attention score: 1.7735223770141602\n",
">chrX_48863882_48864748\n",
"Attention score: 1.5016655921936035\n",
">chrX_48867502_48868359\n",
"Attention score: 1.474397897720337\n",
">chrX_48876817_48877684\n",
"Attention score: 1.5357447862625122\n",
">chrX_48882677_48883075\n",
"Attention score: 1.4887773990631104\n",
">chrX_48883709_48884193\n",
"Attention score: 0.982018232345581\n",
">chrX_48890498_48892185\n",
"Attention score: 1.215963363647461\n",
">chrX_48897112_48898606\n",
"Attention score: 1.4953868389129639\n",
">chrX_48903228_48903628\n",
"Attention score: 1.383549451828003\n",
">chrX_48903701_48904940\n",
"Attention score: 1.909240484237671\n",
">chrX_48910900_48912880\n",
"Attention score: 1.3571687936782837\n",
">chrX_48917151_48917757\n",
"Attention score: 1.1892554759979248\n",
">chrX_48918576_48919783\n",
"Attention score: 1.1372087001800537\n",
">chrX_48919990_48920327\n",
"Attention score: 1.505488395690918\n"
]
}
],
"source": [
"for i in range (len(attention_score)): \n",
" print(peak_id[i])\n",
" print('Attention score: ', attention_score[i])"
]
},
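{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9",
"metadata": {},
"outputs": [],
"source": [
"## A small usage sketch: rank the peaks by attention score, highest first.\n",
"ranked = sorted(zip(peak_id, attention_score), key=lambda t: -t[1])\n",
"for pid, score in ranked[:5]:\n",
"    print(pid, round(score, 3))"
]
}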
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}