codes/example_attention_compute.ipynb (343 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3f1c0de",
   "metadata": {},
   "source": [
    "# Example: computing per-peak attention scores\n",
    "\n",
    "This notebook reads DNA sequences and ATAC signal values for a set of peaks, loads the pretrained REformer model components, and prints one attention score per input peak.\n",
    "\n",
    "Run the cells in order from top to bottom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "667768fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from einops import rearrange\n",
    "from einops.layers.torch import Rearrange"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e8ea0f",
   "metadata": {},
   "source": [
    "## Data processing\n",
    "\n",
    "An example input file (`data/info_attention.txt`) is provided for reference.\n",
    "\n",
    "Please provide the following information in the input file:\n",
    "\n",
    "1. **n x 1029 bp DNA sequences**, where n is the number of input peaks. Each sequence must be exactly 1029 bp long, and its central nucleotide must lie at the centre of the peak.\n",
    "2. **n x 1029 ATAC signal values** taken from the `.bigWig` file. Each value must correspond to the same position as the sampled nucleotide.\n",
    "3. **n IDs**, one for each peak, for reference.\n",
    "\n",
    "Write this information to a `.txt` file with three lines per peak, for example:\n",
    "\n",
    "```\n",
    "line 1 (ID of this peak): >chr12 135145-135516\n",
    "line 2 (1029bp DNA): ATCGATCG ... ... TCGA\n",
    "line 3 (1029bp ATAC): 1.28971 1.11121 ... ... 0.01234\n",
    "line 4 (next peak ID): ... ...\n",
    "... ...\n",
    "```\n",
    "\n",
    "The input for n peaks should therefore have 3*n lines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e51b17df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([44, 1024]) torch.Size([44, 1024])\n"
     ]
    }
   ],
   "source": [
    "from data_processing import *\n",
    "\n",
    "KMER_SIZE = 6        # value passed to tokenizer() and pre_pro(); 1029 positions give 1024 tokens per peak\n",
    "LINES_PER_PEAK = 3   # peak ID, DNA sequence, ATAC signal\n",
    "\n",
    "# NOTE: torch.load unpickles the file, so only load files from a trusted source.\n",
    "genome_dict = torch.load('../data/kmer_dict.pkl')\n",
    "\n",
    "raw_input_data = narrowPeak_Reader('../data/info_attention.txt')\n",
    "\n",
    "n_lines = len(raw_input_data)\n",
    "data_len, leftover = divmod(n_lines, LINES_PER_PEAK)\n",
    "if leftover:\n",
    "    # An incomplete trailing record is skipped, but not silently.\n",
    "    warnings.warn(\n",
    "        f'Input has {n_lines} lines, which is not a multiple of {LINES_PER_PEAK}; '\n",
    "        f'the last {leftover} line(s) are ignored.'\n",
    "    )\n",
    "\n",
    "peak_id, dna_tokens, atac_tokens = [], [], []\n",
    "for i_data in range(data_len):\n",
    "    first = i_data * LINES_PER_PEAK\n",
    "    id_row = raw_input_data[first]\n",
    "    dna_row = raw_input_data[first + 1]\n",
    "    atac_row = raw_input_data[first + 2]\n",
    "\n",
    "    peak_id.append(id_row[0])\n",
    "    dna_tokens.append(pre_processing(tokenizer(dna_row[0], KMER_SIZE), genome_dict))\n",
    "    # The signal values are read as text; convert them without modifying raw_input_data.\n",
    "    atac_tokens.append(pre_pro([float(value) for value in atac_row], KMER_SIZE))\n",
    "\n",
    "dna_in = torch.tensor(dna_tokens)\n",
    "atac_in = torch.tensor(atac_tokens)\n",
    "\n",
    "print(dna_in.shape, atac_in.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "661ec912",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Specify your TSS location and strand ##\n",
    "## tss_loc is used as a 0-based index into the list of input peaks and should point to the ATAC peak closest to the TSS. ##\n",
    "## Our example gene is GATA1: the closest ATAC peak to its TSS is at index 20, and the gene is on the + strand. ##\n",
    "tss_loc = 20\n",
    "direction = '+'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99149c3c",
   "metadata": {},
   "source": [
    "## Load REformer model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eb4e9787",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The star imports are kept in this cell on purpose: the names they export are not\n",
    "# visible from this notebook, and unpickling the saved modules may rely on them.\n",
    "from models import *\n",
    "from attention import *\n",
    "\n",
    "# Use the first GPU when one is available, otherwise fall back to the CPU.\n",
    "cuda = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "MODEL_DIR = '../pretrained_models'\n",
    "\n",
    "\n",
    "def _load_pretrained(file_name):\n",
    "    \"\"\"Load one pickled object from MODEL_DIR onto the CPU, then move it to `cuda`.\n",
    "\n",
    "    NOTE: torch.load unpickles the file, so only load files from a trusted source.\n",
    "    NOTE(review): recent PyTorch releases are reported to default to\n",
    "    weights_only=True, which may reject fully pickled modules - confirm against\n",
    "    the installed version.\n",
    "    \"\"\"\n",
    "    return torch.load(f'{MODEL_DIR}/{file_name}', map_location='cpu').to(cuda)\n",
    "\n",
    "\n",
    "dna_embed = _load_pretrained('dna_embed.pkl')\n",
    "atac_embed = _load_pretrained('atac_embed.pkl')\n",
    "pos1_embed = _load_pretrained('pos1_embed.pkl')\n",
    "pos2_embed = _load_pretrained('pos2_embed.pkl')\n",
    "tss_embed = _load_pretrained('tss_embed.pkl')\n",
    "pad_embed = _load_pretrained('pad_embed.pkl')\n",
    "encoder_1 = _load_pretrained('transformer_1.pkl')\n",
    "encoder_2 = _load_pretrained('transformer_2.pkl')\n",
    "encoder_3 = _load_pretrained('transformer_3.pkl')\n",
    "atten_pool = _load_pretrained('atten_pool.pkl')\n",
    "ff_net = _load_pretrained('feedforward.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4cab5fc",
   "metadata": {},
   "source": [
    "## Compute attention score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "005228d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_PEAKS = 150    # length of the peak-level sequence given to encoder_3 (real peaks + padding)\n",
    "N_CHUNKS = 8       # the tokens of each peak are reshaped into 8 chunks ...\n",
    "CHUNK_LEN = 128    # ... of 128 tokens each (8 * 128 = 1024)\n",
    "EMBED_DIM = 2048   # last dimension used when regrouping the encoder_1 output per peak\n",
    "\n",
    "\n",
    "def _tss_relative_positions(n_peaks, tss_loc, direction, total_len=MAX_PEAKS):\n",
    "    \"\"\"Build the index tensor that is looked up in tss_embed.\n",
    "\n",
    "    The peak at tss_loc and every padding slot get 0. Moving away from tss_loc,\n",
    "    the peaks on one side get 1, 3, 5, ... and the peaks on the other side get\n",
    "    2, 4, 6, ...: for direction '+' the odd values go to the peaks before\n",
    "    tss_loc, for direction '-' they go to the peaks after it.\n",
    "\n",
    "    Args:\n",
    "        n_peaks: number of real (non-padding) peaks.\n",
    "        tss_loc: 0-based index of the peak closest to the TSS.\n",
    "        direction: '+' or '-'.\n",
    "        total_len: length of the returned tensor, padding included.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D int64 CPU tensor of length total_len.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if direction is neither '+' nor '-', if tss_loc is not a\n",
    "            valid peak index, or if n_peaks exceeds total_len.\n",
    "    \"\"\"\n",
    "    if direction not in ('+', '-'):\n",
    "        raise ValueError(f'direction must be + or -, got {direction!r}')\n",
    "    if n_peaks > total_len:\n",
    "        raise ValueError(f'at most {total_len} peaks are supported, got {n_peaks}')\n",
    "    if not 0 <= tss_loc < n_peaks:\n",
    "        raise ValueError(f'tss_loc must be in [0, {n_peaks}), got {tss_loc}')\n",
    "\n",
    "    offsets = torch.arange(n_peaks) - tss_loc\n",
    "    dist = offsets.abs()\n",
    "    odd = 2 * dist - 1\n",
    "    even = 2 * dist\n",
    "    before, after = (odd, even) if direction == '+' else (even, odd)\n",
    "\n",
    "    codes = torch.where(offsets < 0, before, after)\n",
    "    codes[tss_loc] = 0\n",
    "\n",
    "    # Only the real peaks are written, so padding slots always stay 0, even when\n",
    "    # tss_loc is the first or the last peak.\n",
    "    positions = torch.zeros(total_len, dtype=torch.long)\n",
    "    positions[:n_peaks] = codes\n",
    "    return positions\n",
    "\n",
    "\n",
    "dna_in = dna_in.to(cuda)\n",
    "sig_in = atac_in.to(cuda)\n",
    "n_peaks = dna_in.shape[0]\n",
    "\n",
    "pos1 = torch.arange(1, CHUNK_LEN + 2, device=cuda)   # 1..129: one extra slot for the prepended CLS token\n",
    "pos2 = torch.arange(1, N_CHUNKS + 1, device=cuda)    # 1..8\n",
    "pos3 = _tss_relative_positions(n_peaks, tss_loc, direction).to(cuda)\n",
    "\n",
    "with torch.no_grad():\n",
    "\n",
    "    # Stage 1: encode every chunk; token id 1 is prepended as CLS and the\n",
    "    # encoder_1 output at position 0 is kept for each chunk.\n",
    "    CLS = dna_embed(torch.ones(n_peaks * N_CHUNKS, 1, dtype=torch.long, device=cuda))\n",
    "    x_POS_1 = pos1_embed(pos1)\n",
    "    x_mul = (\n",
    "        dna_embed(dna_in.int().reshape(n_peaks * N_CHUNKS, CHUNK_LEN))\n",
    "        + atac_embed(sig_in.int().reshape(n_peaks * N_CHUNKS, CHUNK_LEN))\n",
    "    )\n",
    "    x_embed = torch.cat((CLS, x_mul), dim=1)\n",
    "    x_enc_1 = encoder_1(x_embed + x_POS_1)[:, 0, :].reshape(n_peaks, N_CHUNKS, EMBED_DIM)\n",
    "\n",
    "    # Stage 2: encode the chunks of each peak, then pool them with atten_pool.\n",
    "    x_POS_2 = pos2_embed(pos2)\n",
    "    x_enc_2 = encoder_2(x_enc_1 + x_POS_2)\n",
    "    x_enc_2 = rearrange(x_enc_2, 'b n d -> b d n')\n",
    "    x_enc_2 = atten_pool(x_enc_2)\n",
    "    x_enc_2 = rearrange(x_enc_2, 'b d n -> b n d').squeeze(1)\n",
    "\n",
    "    # Stage 3: pad the peak sequence to MAX_PEAKS with pad_embed(0) and add the TSS-relative embedding.\n",
    "    x_pad = pad_embed(torch.zeros(MAX_PEAKS - x_enc_2.shape[0], dtype=torch.long, device=cuda))\n",
    "    x_eb3 = torch.cat((x_enc_2, x_pad), dim=0)\n",
    "    x_POS_3 = tss_embed(pos3)\n",
    "    x_enc_3 = x_eb3 + x_POS_3\n",
    "\n",
    "    attn_probs = extract_selfattention_maps(encoder_3.encoder, x_enc_3.unsqueeze(0))\n",
    "    # Softmax over dim 2 of the first returned map, then mean over dim 0 and sum over\n",
    "    # the next dim; only the entries of the real peaks are kept.\n",
    "    attention_score = F.softmax(attn_probs[0], dim=2).mean(0).sum(0)[0:n_peaks].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5acc762e",
   "metadata": {},
   "source": [
    "## Print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b7151318",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">chrX_48648355_48648753\n",
      "Attention score: 1.028488039970398\n",
      ">chrX_48652218_48652494\n",
      "Attention score: 1.3975310325622559\n",
      ">chrX_48660471_48661224\n",
      "Attention score: 1.28517746925354\n",
      ">chrX_48675912_48677365\n",
      "Attention score: 1.1235225200653076\n",
      ">chrX_48680543_48680748\n",
      "Attention score: 1.3664084672927856\n",
      ">chrX_48683426_48683886\n",
      "Attention score: 1.525336503982544\n",
      ">chrX_48689104_48689575\n",
      "Attention score: 0.9983937740325928\n",
      ">chrX_48695968_48697301\n",
      "Attention score: 0.9551467299461365\n",
      ">chrX_48701826_48702668\n",
      "Attention score: 1.094896674156189\n",
      ">chrX_48737081_48738023\n",
      "Attention score: 1.2857067584991455\n",
      ">chrX_48750823_48751127\n",
      "Attention score: 1.350295066833496\n",
      ">chrX_48753382_48754579\n",
      "Attention score: 1.7550113201141357\n",
      ">chrX_48761672_48762478\n",
      "Attention score: 1.7182071208953857\n",
      ">chrX_48765487_48765941\n",
      "Attention score: 1.5139516592025757\n",
      ">chrX_48770352_48771155\n",
      "Attention score: 1.1152039766311646\n",
      ">chrX_48776735_48777456\n",
      "Attention score: 1.668673038482666\n",
      ">chrX_48779520_48779697\n",
      "Attention score: 1.3726859092712402\n",
      ">chrX_48782680_48783266\n",
      "Attention score: 1.2104418277740479\n",
      ">chrX_48785143_48785558\n",
      "Attention score: 2.014192581176758\n",
      ">chrX_48785717_48786295\n",
      "Attention score: 1.4691805839538574\n",
      ">chrX_48786374_48786792\n",
      "Attention score: 2.1485726833343506\n",
      ">chrX_48788441_48788864\n",
      "Attention score: 1.2086058855056763\n",
      ">chrX_48794064_48795657\n",
      "Attention score: 2.115185022354126\n",
      ">chrX_48800485_48801011\n",
      "Attention score: 1.7803266048431396\n",
      ">chrX_48801263_48801441\n",
      "Attention score: 1.9881824254989624\n",
      ">chrX_48801571_48802610\n",
      "Attention score: 1.839195728302002\n",
      ">chrX_48816491_48816877\n",
      "Attention score: 0.9433814287185669\n",
      ">chrX_48823026_48823601\n",
      "Attention score: 1.2529704570770264\n",
      ">chrX_48833554_48834232\n",
      "Attention score: 1.0208148956298828\n",
      ">chrX_48841756_48842086\n",
      "Attention score: 1.1173759698867798\n",
      ">chrX_48854965_48855186\n",
      "Attention score: 1.7735223770141602\n",
      ">chrX_48863882_48864748\n",
      "Attention score: 1.5016655921936035\n",
      ">chrX_48867502_48868359\n",
      "Attention score: 1.474397897720337\n",
      ">chrX_48876817_48877684\n",
      "Attention score: 1.5357447862625122\n",
      ">chrX_48882677_48883075\n",
      "Attention score: 1.4887773990631104\n",
      ">chrX_48883709_48884193\n",
      "Attention score: 0.982018232345581\n",
      ">chrX_48890498_48892185\n",
      "Attention score: 1.215963363647461\n",
      ">chrX_48897112_48898606\n",
      "Attention score: 1.4953868389129639\n",
      ">chrX_48903228_48903628\n",
      "Attention score: 1.383549451828003\n",
      ">chrX_48903701_48904940\n",
      "Attention score: 1.909240484237671\n",
      ">chrX_48910900_48912880\n",
      "Attention score: 1.3571687936782837\n",
      ">chrX_48917151_48917757\n",
      "Attention score: 1.1892554759979248\n",
      ">chrX_48918576_48919783\n",
      "Attention score: 1.1372087001800537\n",
      ">chrX_48919990_48920327\n",
      "Attention score: 1.505488395690918\n"
     ]
    }
   ],
   "source": [
    "# One ID line and one score line per input peak, in input order.\n",
    "for peak_name, score in zip(peak_id, attention_score):\n",
    "    print(peak_name)\n",
    "    print('Attention score: ', score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}