In [1]:
import torch
import torch.nn as nn
import numpy as np
import math
from einops.layers.torch import Rearrange
from einops import rearrange
import torch.nn.functional as F

In [2]:
## Data processing ##

## An example input file (data/info_attention.txt) is provided for reference. ##

## Please provide the following information in that file:
## (1) n x 1029 bp DNA sequences, where n is the number of input peaks. Each sequence must be exactly 1029 bp long and centered on the central nucleotide of its peak (1029 bp yields 1029 - 6 + 1 = 1024 overlapping 6-mers, matching the 1024-token inputs printed below);
## (2) n x 1029 ATAC signal values taken from the .bigWig file. The i-th value must be the signal at the same genomic position as the i-th nucleotide of the corresponding DNA sequence;
## (3) n peak IDs, one per peak, used to label the results.

## Write this information to a .txt file using three lines per peak, for example:
## line 1 (ID of this peak): >chr12 135145-135516 
## line 2 (1029 bp DNA): ATCGATCG ... ... TCGA
## line 3 (1029 ATAC values): 1.28971 1.11121 ... ... 0.01234
## line 4 (ID of the next peak): ... ...
## ... ...
## The input for n peaks should therefore contain 3*n lines.

In [3]:
from data_processing import *

# Lookup table handed to pre_processing (presumably k-mer -> token id; verify
# against data_processing) and the parsed rows of the example input file.
genome_dict    = torch.load('../data/kmer_dict.pkl')
raw_input_data = narrowPeak_Reader('../data/info_attention.txt')

# Each peak occupies three consecutive rows: ID, DNA sequence, ATAC signal.
data_len = len(raw_input_data) // 3
peak_id, dna_in, atac_in = [], [], []

for record_start in range(0, 3 * data_len, 3):
    id_row   = raw_input_data[record_start]
    dna_row  = raw_input_data[record_start + 1]
    atac_row = raw_input_data[record_start + 2]

    peak_id.append(id_row[0])
    # Tokenize the DNA string with k = 6, then encode it with genome_dict.
    dna_in.append(pre_processing(tokenizer(dna_row[0], 6), genome_dict))

    # Convert the signal entries to floats in place before handing the row to pre_pro.
    for col, raw_value in enumerate(atac_row):
        atac_row[col] = float(raw_value)
    atac_in.append(pre_pro(atac_row, 6))

# Stack the per-peak lists into tensors (one row per peak).
dna_in  = torch.tensor(dna_in)
atac_in = torch.tensor(atac_in)

print(dna_in.shape, atac_in.shape)

torch.Size([44, 1024]) torch.Size([44, 1024])


In [4]:
## Specify your TSS location and strand ##
## Our example gene is GATA1: the ATAC peak closest to its TSS is the 20th peak, and the gene is on the + strand ##
## tss_loc is used later as a direct index into the per-peak position codes (pos3[tss_loc]).
## NOTE(review): Python indices are 0-based -- confirm that "20th peak" really means index 20.
## direction must be '+' or '-'; the attention cell only handles these two values.
tss_loc   = 20
direction = '+'

In [5]:
## Load REformer model ##

In [6]:
from models import *
from attention import *

# Prefer the first GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA. The variable keeps the name `cuda` because later cells
# reference it when moving tensors.
cuda = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')


def _load_pretrained(filename):
    """Load one pickled REformer component and prepare it for inference.

    The file is read from ../pretrained_models/, deserialized onto the CPU,
    moved to the `cuda` device selected above and, when it is an nn.Module,
    switched to eval mode so that dropout / normalization layers behave
    deterministically while attention scores are computed.

    Args:
        filename: File name inside ../pretrained_models/ (e.g. "dna_embed.pkl").

    Returns:
        The loaded object, located on `cuda`.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects -- only load
    # checkpoint files obtained from a trusted source.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects whole-module pickles such as these;
    # pass weights_only=False there if loading fails.
    component = torch.load("../pretrained_models/" + filename, map_location='cpu').to(cuda)
    if isinstance(component, nn.Module):
        component.eval()
    return component


dna_embed  = _load_pretrained("dna_embed.pkl")
atac_embed = _load_pretrained("atac_embed.pkl")
pos1_embed = _load_pretrained("pos1_embed.pkl")
pos2_embed = _load_pretrained("pos2_embed.pkl")
tss_embed  = _load_pretrained("tss_embed.pkl")
pad_embed  = _load_pretrained("pad_embed.pkl")
encoder_1  = _load_pretrained("transformer_1.pkl")
encoder_2  = _load_pretrained("transformer_2.pkl")
encoder_3  = _load_pretrained("transformer_3.pkl")
atten_pool = _load_pretrained("atten_pool.pkl")
ff_net     = _load_pretrained("feedforward.pkl")

In [7]:
## compute attention score ##

In [8]:
## Move the inputs to the compute device ##
dna_in = dna_in.to(cuda)
sig_in = atac_in.to(cuda)

## Fixed sizes used by this cell ##
N_CHUNKS  = 8     # each peak's token row is reshaped into 8 chunks ...
CHUNK_LEN = 128   # ... of 128 tokens each
EMBED_DIM = 2048  # width of the per-chunk vector taken from encoder_1's output
MAX_PEAKS = 150   # the peak-level sequence is padded to this many slots

n_peaks = dna_in.shape[0]

## Validate the user-supplied configuration before touching the model ##
if direction not in ('+', '-'):
    # Previously any other value silently left every TSS position code at 0.
    raise ValueError("direction must be '+' or '-', got {!r}".format(direction))
if n_peaks > MAX_PEAKS:
    raise ValueError("got {} peaks, but at most {} are supported".format(n_peaks, MAX_PEAKS))
if not 0 <= tss_loc < n_peaks:
    raise ValueError("tss_loc must be in [0, {}), got {!r}".format(n_peaks, tss_loc))

## Position ids ##
# 1..129 for the [CLS] slot plus the 128 tokens of a chunk, and 1..8 for the
# chunks of a peak (same values the element-wise loops used to produce).
pos1 = torch.arange(1, CHUNK_LEN + 2, dtype=torch.int64, device=cuda)
pos2 = torch.arange(1, N_CHUNKS + 1, dtype=torch.int64, device=cuda)

# TSS-relative code per peak: 0 at the TSS peak; at distance d (in peaks) the
# code is 2*d - 1 on one side and 2*d on the other. For '+' the odd codes go to
# the peaks before tss_loc, for '-' to the peaks after it. Only the first
# n_peaks slots are written, so padding slots keep code 0 even when the TSS
# peak is the first or last one (the old index arithmetic wrapped around to the
# last slot for tss_loc == 0 and spilled into the padding for the last peak).
peak_offset = torch.arange(n_peaks, dtype=torch.int64, device=cuda) - tss_loc
odd_side = (peak_offset < 0) if direction == '+' else (peak_offset > 0)
pos3 = torch.zeros(MAX_PEAKS, dtype=torch.int64, device=cuda)
pos3[:n_peaks] = 2 * peak_offset.abs() - odd_side.to(torch.int64)

with torch.no_grad():

    # Token id 1 is embedded and prepended to every chunk as its [CLS] slot.
    CLS     = dna_embed(torch.ones(n_peaks * N_CHUNKS, 1, dtype=torch.int64, device=cuda))
    x_POS_1 = pos1_embed(pos1)
    # DNA and ATAC embeddings are summed token by token; both inputs are cast to
    # integer ids and reshaped from (n_peaks, 1024) to (n_peaks*8, 128).
    x_mul   = (dna_embed(dna_in.int().reshape(n_peaks * N_CHUNKS, CHUNK_LEN))
               + atac_embed(sig_in.int().reshape(n_peaks * N_CHUNKS, CHUNK_LEN)))
    x_embed = torch.cat((CLS, x_mul), dim=1)
    # Keep only the [CLS] output of each chunk and regroup the chunks per peak.
    x_enc_1 = encoder_1(x_embed + x_POS_1)[:, 0, :].reshape(n_peaks, N_CHUNKS, EMBED_DIM)
    x_POS_2 = pos2_embed(pos2)
    x_enc_2 = encoder_2(x_enc_1 + x_POS_2)
    # atten_pool is applied on a (b, d, n) layout; its result is rearranged back
    # and the singleton axis dropped (presumably one vector per peak -- confirm
    # against atten_pool's definition).
    x_enc_2 = rearrange(x_enc_2, 'b n d -> b d n')
    x_enc_2 = atten_pool(x_enc_2)
    x_enc_2 = rearrange(x_enc_2, 'b d n -> b n d').squeeze(1)
    # Pad the peak-level sequence up to MAX_PEAKS slots with the embedding of id 0.
    x_pad   = pad_embed(torch.zeros(MAX_PEAKS - x_enc_2.shape[0], dtype=torch.int64, device=cuda))
    x_eb3   = torch.cat((x_enc_2, x_pad), dim=0)
    x_POS_3 = tss_embed(pos3)
    x_enc_3 = x_eb3 + x_POS_3

    # Take the first map returned by extract_selfattention_maps (presumably the
    # first encoder layer -- verify in attention.py), softmax it over dim 2,
    # average over dim 0, sum over the next axis and keep the real peaks only.
    attn_probs = extract_selfattention_maps(encoder_3.encoder, x_enc_3.unsqueeze(0))
    SM = nn.Softmax(dim=2)
    attention_score = SM(attn_probs[0]).mean(0).sum(0)[0:n_peaks].tolist()

In [9]:
## print results ##

In [10]:
# Report every peak ID followed by the attention score computed for it.
for peak_idx, score in enumerate(attention_score):
    print(peak_id[peak_idx])
    print('Attention score: ', score)

>chrX_48648355_48648753
Attention score:  1.028488039970398
>chrX_48652218_48652494
Attention score:  1.3975310325622559
>chrX_48660471_48661224
Attention score:  1.28517746925354
>chrX_48675912_48677365
Attention score:  1.1235225200653076
>chrX_48680543_48680748
Attention score:  1.3664084672927856
>chrX_48683426_48683886
Attention score:  1.525336503982544
>chrX_48689104_48689575
Attention score:  0.9983937740325928
>chrX_48695968_48697301
Attention score:  0.9551467299461365
>chrX_48701826_48702668
Attention score:  1.094896674156189
>chrX_48737081_48738023
Attention score:  1.2857067584991455
>chrX_48750823_48751127
Attention score:  1.350295066833496
>chrX_48753382_48754579
Attention score:  1.7550113201141357
>chrX_48761672_48762478
Attention score:  1.7182071208953857
>chrX_48765487_48765941
Attention score:  1.5139516592025757
>chrX_48770352_48771155
Attention score:  1.1152039766311646
>chrX_48776735_48777456
Attention score:  1.668673038482666
>chrX_48779520_48779697
Attenti