codes/example_gexp_compute.ipynb:
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0d9893c6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import math\n",
"from einops.layers.torch import Rearrange\n",
"from einops import rearrange\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0a09d737",
"metadata": {},
"outputs": [],
"source": [
"## Data processing ##\n",
"\n",
"## We provided two example data (data/info_ko_org.txt and data/info_ko_ko.txt) for reference. ##\n",
"\n",
"## Please provide the following information in that file: \n",
"## (1) n x 1029bp DNA sequences, n is the number of input peaks. For each peak, the central nucleotide should be sampled at the central of that peak, and the DNA sequence length should be of 1029 bp;\n",
"## (2) n x 1029bp ATAC signal values. The ATAC signal values should be from the .bigWig file, each value is at the same position of your sampled nucleotide.\n",
"## (3) Please give n x IDs for reference of each peak\n",
"\n",
"## Please write your information in a .txt file, foe example:\n",
"## line 1 (ID of this peak): >chr12 135145-135516 \n",
"## line 2 (1029bp DNA): ATCGATCG ... ... TCGA\n",
"## line 3 (1029bp ATAC): 1.28971 1.11121 ... ... 0.01234\n",
"## line 4 (next peak ID): ... ...\n",
"## ... ...\n",
"## input information of n peaks should have 3*n lines"
]
},
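{
"cell_type": "code",
"execution_count": null,
"id": "b7f1d2a0",
"metadata": {},
"outputs": [],
"source": [
"## A minimal sketch of writing the 3-lines-per-peak file described above. ##\n",
"## `peaks` is a hypothetical list of (peak_id, dna_seq, atac_vals) tuples; ##\n",
"## dna_seq is a 1029-character string and atac_vals holds 1029 floats. ##\n",
"peaks = [('>chr12 135145-135516', 'ATCG'*257 + 'A', [0.0]*1029)]\n",
"with open('my_info.txt', 'w') as f:\n",
"    for pid, dna_seq, atac_vals in peaks:\n",
"        assert len(dna_seq) == 1029 and len(atac_vals) == 1029\n",
"        f.write(pid + '\\n')                                   # line 1: peak ID\n",
"        f.write(dna_seq + '\\n')                               # line 2: 1029 bp DNA\n",
"        f.write(' '.join(str(v) for v in atac_vals) + '\\n')   # line 3: 1029 ATAC values"
]
},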
{
"cell_type": "code",
"execution_count": 3,
"id": "a50ee3aa",
"metadata": {},
"outputs": [],
"source": [
"## If you want to simulate the KO or KI of functional elements:\n",
"## Please provid the .txt data BEFORE and AFTER the KO or KI, and compute their gene expression respectively"
]
},
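{
"cell_type": "code",
"execution_count": null,
"id": "c3e5a9f2",
"metadata": {},
"outputs": [],
"source": [
"## Hedged sketch of building the AFTER file for a KO: copy the BEFORE file ##\n",
"## and zero the ATAC line of the knocked-out peak. This mirrors how ##\n",
"## data/info_ko_ko.txt was derived later in this notebook; the file names ##\n",
"## and target ID are illustrative only. ##\n",
"target_id = '>chrX_48761672_48762478'\n",
"with open('info_before.txt') as f:\n",
"    lines = f.read().splitlines()\n",
"for i in range(0, len(lines), 3):     # 3 lines per peak: ID, DNA, ATAC\n",
"    if lines[i].strip() == target_id:\n",
"        n_vals = len(lines[i+2].split())\n",
"        lines[i+2] = ' '.join(['0.0'] * n_vals)\n",
"with open('info_after.txt', 'w') as f:\n",
"    f.write('\\n'.join(lines) + '\\n')"
]
},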
{
"cell_type": "code",
"execution_count": 4,
"id": "2210f26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([44, 1024]) torch.Size([44, 1024])\n"
]
}
],
"source": [
"from data_processing import *\n",
"genome_dict = torch.load('../data/kmer_dict.pkl')\n",
"\n",
"raw_input_data = narrowPeak_Reader('../data/info_ko_org.txt')\n",
"\n",
"data_len = int(len(raw_input_data)/3)\n",
"peak_id = list(range(data_len))\n",
"dna_in = list(range(data_len))\n",
"atac_in = list(range(data_len))\n",
"\n",
"for i_data in range (data_len): \n",
" peak_id[i_data] = raw_input_data[i_data*3][0]\n",
" dna_in[i_data] = pre_processing(tokenizer(raw_input_data[i_data*3+1][0], 6), genome_dict)\n",
" for i in range (len(raw_input_data[i_data*3+2])):\n",
" raw_input_data[i_data*3+2][i]=float(raw_input_data[i_data*3+2][i])\n",
" atac_in[i_data] = pre_pro(raw_input_data[i_data*3+2], 6)\n",
"dna_in = torch.tensor(dna_in)\n",
"atac_in = torch.tensor(atac_in)\n",
"\n",
"print(dna_in.shape, atac_in.shape)"
]
},
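{
"cell_type": "code",
"execution_count": null,
"id": "d4a7b8e1",
"metadata": {},
"outputs": [],
"source": [
"## tokenizer / pre_processing / pre_pro come from data_processing. As a ##\n",
"## rough, hypothetical illustration only (not the actual implementation): ##\n",
"## overlapping 6-mer tokenization is consistent with the printed shapes, ##\n",
"## since a 1029 bp sequence yields 1029 - 6 + 1 = 1024 tokens per peak. ##\n",
"def sliding_kmers(seq, k=6):\n",
"    # one k-mer starting at every position that fits a full window\n",
"    return [seq[i:i+k] for i in range(len(seq) - k + 1)]\n",
"\n",
"print(len(sliding_kmers('A'*1029)))  # 1024, matching dna_in.shape[1]"
]
},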
{
"cell_type": "code",
"execution_count": 5,
"id": "cca09524",
"metadata": {},
"outputs": [],
"source": [
"## Specify your TSS location and strand ##\n",
"## Our example gene is GATA1, the closest ATAC peak to TSS is the 20th peak, and the gene is + strand ##\n",
"tss_loc = 20\n",
"direction = '+'"
]
},
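{
"cell_type": "code",
"execution_count": null,
"id": "e1f4a8c5",
"metadata": {},
"outputs": [],
"source": [
"## A hedged, illustrative helper for locating tss_loc, assuming full ID ##\n",
"## strings of the form '>chr12 135145-135516' and a known TSS coordinate. ##\n",
"## This is not part of the original pipeline; the name and ID format are ##\n",
"## assumptions. ##\n",
"def nearest_peak(id_strings, tss_chrom, tss_pos):\n",
"    best, best_d = None, None\n",
"    for idx, pid in enumerate(id_strings):\n",
"        chrom, span = pid.lstrip('>').split()\n",
"        start, end = map(int, span.split('-'))\n",
"        if chrom == tss_chrom:\n",
"            d = abs((start + end)//2 - tss_pos)   # distance from peak center\n",
"            if best_d is None or d < best_d:\n",
"                best, best_d = idx, d\n",
"    return best"
]
},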
{
"cell_type": "code",
"execution_count": 6,
"id": "279ecd21",
"metadata": {},
"outputs": [],
"source": [
"## Load REformer model ##"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d270ec6a",
"metadata": {},
"outputs": [],
"source": [
"from models import *\n",
"from attention import *\n",
"\n",
"cuda = torch.device('cuda', 0)\n",
"\n",
"dna_embed = torch.load(\"../pretrained_models/dna_embed.pkl\" , map_location='cpu').to(cuda)\n",
"atac_embed = torch.load(\"../pretrained_models/atac_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pos1_embed = torch.load(\"../pretrained_models/pos1_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pos2_embed = torch.load(\"../pretrained_models/pos2_embed.pkl\" , map_location='cpu').to(cuda)\n",
"tss_embed = torch.load(\"../pretrained_models/tss_embed.pkl\" , map_location='cpu').to(cuda)\n",
"pad_embed = torch.load(\"../pretrained_models/pad_embed.pkl\" , map_location='cpu').to(cuda)\n",
"encoder_1 = torch.load(\"../pretrained_models/transformer_1.pkl\", map_location='cpu').to(cuda)\n",
"encoder_2 = torch.load(\"../pretrained_models/transformer_2.pkl\", map_location='cpu').to(cuda)\n",
"encoder_3 = torch.load(\"../pretrained_models/transformer_3.pkl\", map_location='cpu').to(cuda)\n",
"atten_pool = torch.load(\"../pretrained_models/atten_pool.pkl\" , map_location='cpu').to(cuda)\n",
"ff_net = torch.load(\"../pretrained_models/feedforward.pkl\" , map_location='cpu').to(cuda)"
]
},
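{
"cell_type": "code",
"execution_count": null,
"id": "f2b6c9d0",
"metadata": {},
"outputs": [],
"source": [
"## The loading cell above assumes a CUDA GPU at index 0. A hedged fallback ##\n",
"## for CPU-only machines (assuming the pickled modules also run on CPU): ##\n",
"device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')\n",
"# e.g. dna_embed = torch.load('../pretrained_models/dna_embed.pkl', map_location='cpu').to(device)\n",
"# ... and likewise for the remaining modules; this notebook keeps using `cuda`."
]
},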
{
"cell_type": "code",
"execution_count": 8,
"id": "3560322d",
"metadata": {},
"outputs": [],
"source": [
"## compute attention score ##"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c88d9d8e",
"metadata": {},
"outputs": [],
"source": [
"dna_in = dna_in.to(cuda)\n",
"sig_in = atac_in.to(cuda)\n",
"\n",
"pos1 = torch.ones(129, dtype=int).to(cuda)\n",
"for i in range (len(pos1)):\n",
" pos1[i]+=i\n",
"pos2 = torch.ones(8, dtype=int).to(cuda)\n",
"for i in range (len(pos2)):\n",
" pos2[i]+=i\n",
"pos3 = torch.zeros(150, dtype=int).to(cuda)\n",
"pos3[tss_loc] = 0\n",
"if direction=='+':\n",
" pos3[tss_loc-1] = 1\n",
" pos3[tss_loc+1] = 2\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
"if direction=='-':\n",
" pos3[tss_loc-1] = 2\n",
" pos3[tss_loc+1] = 1\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
" \n",
"with torch.no_grad():\n",
" \n",
" CLS = dna_embed(torch.ones(dna_in.shape[0]*8, 1, dtype=int).to(cuda))\n",
" x_POS_1 = pos1_embed(pos1) \n",
" x_mul = dna_embed(dna_in.int().reshape(dna_in.shape[0]*8, 128)) + atac_embed(sig_in.int().reshape(dna_in.shape[0]*8, 128)) \n",
" x_embed = torch.cat((CLS, x_mul), dim=1)\n",
" x_enc_1 = encoder_1(x_embed + x_POS_1)[:,0,:].reshape(dna_in.shape[0],8,2048)\n",
" x_POS_2 = pos2_embed(pos2)\n",
" x_enc_2 = encoder_2(x_enc_1+x_POS_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b n d -> b d n')\n",
" x_enc_2 = atten_pool(x_enc_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b d n -> b n d').squeeze(1)\n",
" x_pad = pad_embed(torch.zeros(150-x_enc_2.shape[0], dtype=int).to(cuda))\n",
" x_eb3 = torch.cat((x_enc_2, x_pad), dim=0)\n",
" x_POS_3 = tss_embed(pos3)\n",
" x_enc_3 = encoder_3(x_eb3 + x_POS_3)\n",
" result = ff_net(x_enc_3).squeeze(1)"
]
},
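{
"cell_type": "code",
"execution_count": null,
"id": "f6b8d1c3",
"metadata": {},
"outputs": [],
"source": [
"## Illustrative sanity check (not part of the original pipeline): with a '+' ##\n",
"## strand gene, peaks upstream of the TSS receive odd indices 1, 3, 5, ... ##\n",
"## and downstream peaks even indices 2, 4, 6, ... ##\n",
"print(pos3[tss_loc-3:tss_loc+4])  # expected: tensor([5, 3, 1, 0, 2, 4, 6], ...)"
]
},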
{
"cell_type": "code",
"execution_count": 10,
"id": "48b55cf9",
"metadata": {},
"outputs": [],
"source": [
"## print results #"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cbfd7435",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([6.4337], device='cuda:0', dtype=torch.bfloat16)\n"
]
}
],
"source": [
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "4e46b270",
"metadata": {},
"outputs": [],
"source": [
"## Here, we additionally showed the KO of chrX_48761672_48762478 (enhancer)\n",
"## In the input .txt file (data/info_ko_ko.txt), we set the ATAC signal in that peak as zero"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df3e7eae",
"metadata": {},
"outputs": [],
"source": [
"## Please notice that if you are performing KI of functional elements, the index of TSS location might change ##\n",
"## Please re-specify your TSS location in KI experiments ##\n",
"tss_loc = 20"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8c012ee3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([44, 1024]) torch.Size([44, 1024])\n"
]
}
],
"source": [
"raw_input_data = narrowPeak_Reader('../data/info_ko_ko.txt')\n",
"\n",
"data_len = int(len(raw_input_data)/3)\n",
"peak_id = list(range(data_len))\n",
"dna_in = list(range(data_len))\n",
"atac_in = list(range(data_len))\n",
"\n",
"for i_data in range (data_len): \n",
" peak_id[i_data] = raw_input_data[i_data*3][0]\n",
" dna_in[i_data] = pre_processing(tokenizer(raw_input_data[i_data*3+1][0], 6), genome_dict)\n",
" for i in range (len(raw_input_data[i_data*3+2])):\n",
" raw_input_data[i_data*3+2][i]=float(raw_input_data[i_data*3+2][i])\n",
" atac_in[i_data] = pre_pro(raw_input_data[i_data*3+2], 6)\n",
"dna_in = torch.tensor(dna_in)\n",
"atac_in = torch.tensor(atac_in)\n",
"\n",
"print(dna_in.shape, atac_in.shape)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a1814d9e",
"metadata": {},
"outputs": [],
"source": [
"dna_in = dna_in.to(cuda)\n",
"sig_in = atac_in.to(cuda)\n",
"\n",
"pos1 = torch.ones(129, dtype=int).to(cuda)\n",
"for i in range (len(pos1)):\n",
" pos1[i]+=i\n",
"pos2 = torch.ones(8, dtype=int).to(cuda)\n",
"for i in range (len(pos2)):\n",
" pos2[i]+=i\n",
"pos3 = torch.zeros(150, dtype=int).to(cuda)\n",
"pos3[tss_loc] = 0\n",
"if direction=='+':\n",
" pos3[tss_loc-1] = 1\n",
" pos3[tss_loc+1] = 2\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
"if direction=='-':\n",
" pos3[tss_loc-1] = 2\n",
" pos3[tss_loc+1] = 1\n",
" for tss_i in range (tss_loc-1):\n",
" pos3[tss_loc-1-tss_i-1] = pos3[tss_loc-1-tss_i]+2\n",
" for tss_i in range (dna_in.shape[0]-tss_loc-2):\n",
" pos3[tss_loc+1+tss_i+1] = pos3[tss_loc+1+tss_i]+2\n",
" \n",
"with torch.no_grad():\n",
" \n",
" CLS = dna_embed(torch.ones(dna_in.shape[0]*8, 1, dtype=int).to(cuda))\n",
" x_POS_1 = pos1_embed(pos1) \n",
" x_mul = dna_embed(dna_in.int().reshape(dna_in.shape[0]*8, 128)) + atac_embed(sig_in.int().reshape(dna_in.shape[0]*8, 128)) \n",
" x_embed = torch.cat((CLS, x_mul), dim=1)\n",
" x_enc_1 = encoder_1(x_embed + x_POS_1)[:,0,:].reshape(dna_in.shape[0],8,2048)\n",
" x_POS_2 = pos2_embed(pos2)\n",
" x_enc_2 = encoder_2(x_enc_1+x_POS_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b n d -> b d n')\n",
" x_enc_2 = atten_pool(x_enc_2)\n",
" x_enc_2 = rearrange(x_enc_2, 'b d n -> b n d').squeeze(1)\n",
" x_pad = pad_embed(torch.zeros(150-x_enc_2.shape[0], dtype=int).to(cuda))\n",
" x_eb3 = torch.cat((x_enc_2, x_pad), dim=0)\n",
" x_POS_3 = tss_embed(pos3)\n",
" x_enc_3 = encoder_3(x_eb3 + x_POS_3)\n",
" result = ff_net(x_enc_3).squeeze(1)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dd0281e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([2.2318], device='cuda:0', dtype=torch.bfloat16)\n"
]
}
],
"source": [
"print(result)"
]
},
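{
"cell_type": "code",
"execution_count": null,
"id": "a2c4e6f8",
"metadata": {},
"outputs": [],
"source": [
"## Comparison: the original input predicted ~6.43, while zeroing the ATAC ##\n",
"## signal of the chrX_48761672_48762478 peak dropped the prediction to ##\n",
"## ~2.23, consistent with this element acting as an activating enhancer ##\n",
"## of GATA1. ##"
]
}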
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}