codes/example_gexp_compute.ipynb (387 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f1c2e7a",
   "metadata": {},
   "source": [
    "# Predict gene expression with REformer: KO example\n",
    "\n",
    "This notebook predicts the expression of one gene from the ATAC-seq peaks around it with the pretrained REformer model, and shows how to simulate a knock-out (KO) or knock-in (KI) of a functional element by predicting expression before and after the edit.\n",
    "\n",
    "**Example:** GATA1 (+ strand), 44 input peaks. In the authors' reference run the prediction was `6.4337` for the original input and `2.2318` after knocking out the enhancer peak `chrX_48761672_48762478`.\n",
    "\n",
    "**Requirements:** `data_processing.py`, `models.py` and `attention.py` on the Python path, and the pretrained components in `../pretrained_models/`. A CUDA GPU is used when available, otherwise the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9893c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from einops import rearrange\n",
    "from einops.layers.torch import Rearrange\n",
    "\n",
    "# The pretrained components are whole pickled modules; presumably these star imports are\n",
    "# needed so unpickling can resolve their classes (TODO confirm and replace with explicit names).\n",
    "from models import *\n",
    "from attention import *\n",
    "\n",
    "# Imported after the star imports so a same-named export of models/attention cannot shadow them.\n",
    "from data_processing import narrowPeak_Reader, pre_pro, pre_processing, tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a09d737",
   "metadata": {},
   "source": [
    "## Input data format\n",
    "\n",
    "Two example inputs are provided for reference: `data/info_ko_org.txt` and `data/info_ko_ko.txt`.\n",
    "\n",
    "An input `.txt` file describes *n* ATAC peaks with **3 lines per peak** (3n lines in total):\n",
    "\n",
    "1. **Peak ID**, e.g. `>chr12 135145-135516`\n",
    "2. **DNA sequence**: 1029 bp whose central nucleotide is sampled at the center of the peak, e.g. `ATCGATCG ... TCGA`\n",
    "3. **ATAC signal**: 1029 values read from the `.bigWig` file, one per nucleotide of line 2 at the same position, e.g. `1.28971 1.11121 ... 0.01234`\n",
    "\n",
    "Each peak ends up as 1024 tokens, the `(n, 1024)` tensor shape seen below (1029 - 6 + 1 = 1024, consistent with overlapping 6-mers)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a50ee3aa",
   "metadata": {},
   "source": [
    "## Simulating a KO or KI\n",
    "\n",
    "To simulate the knock-out (KO) or knock-in (KI) of a functional element, prepare one input file **before** and one **after** the edit, predict expression for each, and compare."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca09524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths (relative to this notebook)\n",
    "DATA_DIR = '../data'\n",
    "MODEL_DIR = '../pretrained_models'\n",
    "KMER_DICT = f'{DATA_DIR}/kmer_dict.pkl'\n",
    "ORIGINAL_INPUT = f'{DATA_DIR}/info_ko_org.txt'\n",
    "KO_INPUT = f'{DATA_DIR}/info_ko_ko.txt'\n",
    "\n",
    "# Checkpoint file of each REformer component, keyed by the name used in predict_expression\n",
    "MODEL_FILES = {\n",
    "    'dna_embed': 'dna_embed.pkl',\n",
    "    'atac_embed': 'atac_embed.pkl',\n",
    "    'pos1_embed': 'pos1_embed.pkl',\n",
    "    'pos2_embed': 'pos2_embed.pkl',\n",
    "    'tss_embed': 'tss_embed.pkl',\n",
    "    'pad_embed': 'pad_embed.pkl',\n",
    "    'encoder_1': 'transformer_1.pkl',\n",
    "    'encoder_2': 'transformer_2.pkl',\n",
    "    'encoder_3': 'transformer_3.pkl',\n",
    "    'atten_pool': 'atten_pool.pkl',\n",
    "    'ff_net': 'feedforward.pkl',\n",
    "}\n",
    "\n",
    "# Input geometry expected by the model\n",
    "KMER = 6                  # k-mer size passed to tokenizer / pre_pro\n",
    "SEGMENTS_PER_PEAK = 8     # each peak's 1024 tokens are split into 8 segments...\n",
    "TOKENS_PER_SEGMENT = 128  # ...of 128 tokens each (8 * 128 = 1024)\n",
    "MAX_PEAKS = 150           # the peak sequence is padded to this length before encoder_3\n",
    "\n",
    "DEVICE = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# TSS location and strand of the target gene. Example gene GATA1 is on the + strand and\n",
    "# its ATAC peak closest to the TSS is at index 20 (0-based) of the input file.\n",
    "TSS_LOC = 20\n",
    "DIRECTION = '+'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a3d8b90",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2210f26d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_trusted(path, map_location='cpu'):\n",
    "    \"\"\"Unpickle an object written by torch.save from a trusted local file.\n",
    "\n",
    "    The REformer checkpoints are whole pickled objects (called directly as modules\n",
    "    below), not state_dicts. PyTorch >= 2.6 defaults torch.load to weights_only=True,\n",
    "    which refuses such pickles, so opt out explicitly where the argument exists.\n",
    "    Unpickling can execute arbitrary code: never point this at untrusted files.\n",
    "    \"\"\"\n",
    "    kwargs = {'map_location': map_location}\n",
    "    if 'weights_only' in inspect.signature(torch.load).parameters:\n",
    "        kwargs['weights_only'] = False\n",
    "    return torch.load(path, **kwargs)\n",
    "\n",
    "\n",
    "def load_peak_file(path, genome_dict, k=KMER):\n",
    "    \"\"\"Read a 3-lines-per-peak input file and tokenize it.\n",
    "\n",
    "    Returns (peak_ids, dna_in, atac_in): the first field of every ID line, and the\n",
    "    per-peak outputs of pre_processing(tokenizer(...)) and pre_pro(...) as tensors.\n",
    "    Raises ValueError if the file is empty or its line count is not a multiple of 3,\n",
    "    rather than silently dropping trailing lines.\n",
    "    \"\"\"\n",
    "    records = narrowPeak_Reader(path)\n",
    "    if len(records) == 0 or len(records) % 3 != 0:\n",
    "        raise ValueError(f'{path}: expected 3 lines per peak, got {len(records)} lines')\n",
    "    peak_ids, dna_rows, atac_rows = [], [], []\n",
    "    for start in range(0, len(records), 3):\n",
    "        id_fields, dna_fields, atac_fields = records[start:start + 3]\n",
    "        peak_ids.append(id_fields[0])\n",
    "        dna_rows.append(pre_processing(tokenizer(dna_fields[0], k), genome_dict))\n",
    "        atac_rows.append(pre_pro([float(value) for value in atac_fields], k))\n",
    "    return peak_ids, torch.tensor(dna_rows), torch.tensor(atac_rows)\n",
    "\n",
    "\n",
    "def build_tss_positions(n_peaks, tss_loc, direction, max_peaks=MAX_PEAKS):\n",
    "    \"\"\"Return a (max_peaks,) long tensor of peak positions relative to the TSS peak.\n",
    "\n",
    "    The TSS peak gets 0. A peak at index distance d from it gets 2*d - 1 when it lies\n",
    "    at a lower index for a '+' gene (higher index for a '-' gene) and 2*d otherwise,\n",
    "    i.e. presumably odd ids upstream and even ids downstream if peaks are listed in\n",
    "    genomic order. Padding slots (index >= n_peaks) stay 0.\n",
    "    Raises ValueError for an unknown strand, too many peaks or an out-of-range TSS,\n",
    "    instead of silently writing into padding or wrapping around with index -1.\n",
    "    \"\"\"\n",
    "    if direction not in ('+', '-'):\n",
    "        raise ValueError(f'direction must be + or -, got {direction!r}')\n",
    "    if not 0 < n_peaks <= max_peaks:\n",
    "        raise ValueError(f'expected 1..{max_peaks} peaks, got {n_peaks}')\n",
    "    if not 0 <= tss_loc < n_peaks:\n",
    "        raise ValueError(f'tss_loc must be in [0, {n_peaks}), got {tss_loc}')\n",
    "    index = torch.arange(n_peaks)\n",
    "    distance = (index - tss_loc).abs()\n",
    "    odd_side = index < tss_loc if direction == '+' else index > tss_loc\n",
    "    positions = torch.zeros(max_peaks, dtype=torch.long)\n",
    "    positions[:n_peaks] = torch.where(odd_side, 2 * distance - 1, 2 * distance)\n",
    "    return positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c88d9d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_expression(reformer, dna_in, atac_in, tss_loc, direction, device=DEVICE):\n",
    "    \"\"\"Run the REformer forward pass for one gene and return the ff_net output.\n",
    "\n",
    "    reformer: dict of loaded components keyed like MODEL_FILES.\n",
    "    dna_in, atac_in: integer token tensors of shape\n",
    "        (n_peaks, SEGMENTS_PER_PEAK * TOKENS_PER_SEGMENT), as returned by load_peak_file.\n",
    "    tss_loc, direction: index of the TSS peak and gene strand ('+' or '-').\n",
    "    \"\"\"\n",
    "    n_peaks = dna_in.shape[0]\n",
    "    n_segments = n_peaks * SEGMENTS_PER_PEAK\n",
    "    pos3 = build_tss_positions(n_peaks, tss_loc, direction).to(device)  # validates inputs\n",
    "    pos1 = torch.arange(1, TOKENS_PER_SEGMENT + 2, device=device)  # CLS + 128 tokens -> 1..129\n",
    "    pos2 = torch.arange(1, SEGMENTS_PER_PEAK + 1, device=device)   # segments -> 1..8\n",
    "    dna_tokens = dna_in.to(device).int().reshape(n_segments, TOKENS_PER_SEGMENT)\n",
    "    atac_tokens = atac_in.to(device).int().reshape(n_segments, TOKENS_PER_SEGMENT)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Stage 1: encode every 128-token segment and keep the output at the CLS slot\n",
    "        # (token id 1, prepended to each segment).\n",
    "        cls = reformer['dna_embed'](torch.ones(n_segments, 1, dtype=torch.long, device=device))\n",
    "        tokens = reformer['dna_embed'](dna_tokens) + reformer['atac_embed'](atac_tokens)\n",
    "        x = torch.cat((cls, tokens), dim=1) + reformer['pos1_embed'](pos1)\n",
    "        x = reformer['encoder_1'](x)[:, 0, :].reshape(n_peaks, SEGMENTS_PER_PEAK, -1)\n",
    "        # Stage 2: encode the segments of each peak and pool them into one vector per peak.\n",
    "        x = reformer['encoder_2'](x + reformer['pos2_embed'](pos2))\n",
    "        x = reformer['atten_pool'](rearrange(x, 'b n d -> b d n'))\n",
    "        x = rearrange(x, 'b d n -> b n d').squeeze(1)\n",
    "        # Stage 3: pad the peak sequence to MAX_PEAKS, add TSS-relative positions, encode, predict.\n",
    "        pad = reformer['pad_embed'](torch.zeros(MAX_PEAKS - x.shape[0], dtype=torch.long, device=device))\n",
    "        x = reformer['encoder_3'](torch.cat((x, pad), dim=0) + reformer['tss_embed'](pos3))\n",
    "        return reformer['ff_net'](x).squeeze(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "279ecd21",
   "metadata": {},
   "source": [
    "## Load the REformer model\n",
    "\n",
    "The checkpoints are full pickles, so they are loaded with `weights_only=False`; only load files you trust."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d270ec6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval() switches off any train-mode behaviour (e.g. dropout) saved with the pickles.\n",
    "reformer = {\n",
    "    name: load_trusted(f'{MODEL_DIR}/{filename}').to(DEVICE).eval()\n",
    "    for name, filename in MODEL_FILES.items()\n",
    "}\n",
    "# k-mer vocabulary used by pre_processing; map_location=None keeps torch.load's default.\n",
    "genome_dict = load_trusted(KMER_DICT, map_location=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3560322d",
   "metadata": {},
   "source": [
    "## Predict expression for the original input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e41c25",
   "metadata": {},
   "outputs": [],
   "source": [
    "peak_ids_org, dna_org, atac_org = load_peak_file(ORIGINAL_INPUT, genome_dict)\n",
    "print(dna_org.shape, atac_org.shape)  # reference run: torch.Size([44, 1024]) for both"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbfd7435",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_org = predict_expression(reformer, dna_org, atac_org, TSS_LOC, DIRECTION)\n",
    "result_org"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e46b270",
   "metadata": {},
   "source": [
    "## Predict expression after the enhancer KO\n",
    "\n",
    "Here we additionally knock out the enhancer peak `chrX_48761672_48762478`: in `data/info_ko_ko.txt` its ATAC signal is set to zero.\n",
    "\n",
    "**KI experiments:** inserting functional elements can shift the index of the TSS peak, so re-specify `TSS_LOC_EDITED` in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c012ee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A KO keeps the same peaks, so the TSS index is unchanged. For a KI, set the index of\n",
    "# the TSS peak in the edited input here.\n",
    "TSS_LOC_EDITED = TSS_LOC\n",
    "\n",
    "peak_ids_ko, dna_ko, atac_ko = load_peak_file(KO_INPUT, genome_dict)\n",
    "print(dna_ko.shape, atac_ko.shape)  # reference run: torch.Size([44, 1024]) for both"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1814d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_ko = predict_expression(reformer, dna_ko, atac_ko, TSS_LOC_EDITED, DIRECTION)\n",
    "result_ko"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48b55cf9",
   "metadata": {},
   "source": [
    "## Compare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0281e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "{\n",
    "    'original': result_org.float().cpu().tolist(),\n",
    "    'KO': result_ko.float().cpu().tolist(),\n",
    "    'KO - original': (result_ko.float() - result_org.float()).cpu().tolist(),\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}