codes/models.py:
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from einops import rearrange
import math
class AttentionPool(nn.Module):
    ## Attention pooling block ##
    def __init__(self, dim, pool_size = 8):
        super().__init__()
        self.pool_size = pool_size
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
        # identity initialisation scaled by 2, so pooling starts close to average pooling
        nn.init.dirac_(self.to_attn_logits.weight)
        with torch.no_grad():
            self.to_attn_logits.weight.mul_(2)

    def forward(self, x):
        # x: (batch, dim, length); length is padded up to a multiple of pool_size
        b, _, n = x.shape
        remainder = n % self.pool_size
        needs_padding = remainder > 0
        if needs_padding:
            pad = self.pool_size - remainder  # pad to the next multiple of pool_size
            x = F.pad(x, (0, pad), value = 0)
            mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
            mask = F.pad(mask, (0, pad), value = True)
        x = self.pool_fn(x)
        logits = self.to_attn_logits(x)
        if needs_padding:
            # padded positions get a large negative logit so they receive zero attention
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.pool_fn(mask), mask_value)
        attn = logits.softmax(dim = -1)
        return (x * attn).sum(dim = -1)
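# Usage sketch (an assumption about the intended call pattern, not taken from the repo):
# AttentionPool expects channel-first input of shape (batch, dim, length) and returns
# (batch, dim, length // pool_size), i.e. an attention-weighted pooling over
# non-overlapping windows of `pool_size` positions.
#
#   pool = AttentionPool(dim = 2048, pool_size = 8)
#   x = torch.randn(2, 2048, 64)            # hypothetical toy shapes
#   y = pool(x)                             # -> (2, 2048, 8)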
class DNA_Embedding(nn.Module):
    ## DNA embedding layer ##
    def __init__(self):
        super(DNA_Embedding, self).__init__()
        dim = 2048
        self.dna_embed = nn.Embedding(4100, dim)

    def forward(self, DNA):
        Genome_embed = self.dna_embed(DNA)
        return Genome_embed
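# Note: the DNA vocabulary has 4100 entries (4096 = 4^6 six-mers plus a few special
# tokens is one plausible reading; the tokenisation itself is defined elsewhere in
# the repo). A minimal lookup sketch with made-up token ids:
#
#   dna_embed = DNA_Embedding()
#   tokens = torch.randint(0, 4100, (1, 150))   # (batch, sequence of token ids)
#   emb = dna_embed(tokens)                     # -> (1, 150, 2048)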
class Sig_Embedding(nn.Module):
    ## ATAC embedding layer ##
    def __init__(self):
        super(Sig_Embedding, self).__init__()
        dim = 2048
        self.sig_embed = nn.Embedding(38, dim)

    def forward(self, signal):
        signal_embed = self.sig_embed(signal)
        return signal_embed
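# Sig_Embedding maps a discretised ATAC-seq signal level (38 bins) into the same
# 2048-dim space as the DNA tokens; presumably the two embeddings are combined
# (e.g. summed) by the training/inference scripts. Sketch with hypothetical inputs:
#
#   sig_embed = Sig_Embedding()
#   signal = torch.randint(0, 38, (1, 150))     # binned ATAC signal per token
#   emb = sig_embed(signal)                     # -> (1, 150, 2048)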
class Encoder(nn.Module):
    ## transformer encoder blocks for Transformer-1, 2 in CREformer-Elementary, and Transformer in CREformer-Regulatory ##
    def __init__(self, d_model=2048, batch_first=True, nhead=32, dim_ffn=2048*4, num_layer=20, drop=0, LNM=1e-05):
        super(Encoder, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_ffn,
            batch_first=True, dropout=drop, layer_norm_eps=LNM)
        self.encoder = nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=num_layer)

    def forward(self, x):
        output = self.encoder(self.norm(x))
        return output
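# With the defaults above this is a 20-layer, 32-head encoder over 2048-dim tokens
# (batch_first, so inputs are (batch, seq_len, d_model)), with a LayerNorm applied
# once before the stack. A toy forward pass, using a deliberately small number of
# layers so it is cheap to instantiate:
#
#   enc = Encoder(d_model=2048, nhead=32, num_layer=1)
#   tokens = torch.randn(1, 150, 2048)
#   out = enc(tokens)                           # -> (1, 150, 2048)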
class Encoder1(nn.Module):
    def __init__(self, d_model=2048, batch_first=True, nhead=32, dim_ffn=2048*4, num_layer=20, drop=0, LNM=1e-05):
        super(Encoder1, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_ffn,
            batch_first=True, dropout=drop, layer_norm_eps=LNM)
        self.encoder = nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=num_layer)

    def forward(self, x):
        output = self.encoder(self.norm(x))
        return output
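# Note: Encoder1 is identical to Encoder (same layers, heads and defaults);
# presumably it exists as a separate class so that checkpoints can address the
# two transformer stacks by distinct names. Usage mirrors Encoder, e.g.
#
#   enc1 = Encoder1(num_layer=1)                # hypothetical small config
#   out = enc1(torch.randn(1, 150, 2048))       # -> (1, 150, 2048)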
class Pos_L1_Embed(nn.Module):
    ## Position-1 embedding layer ##
    def __init__(self):
        super(Pos_L1_Embed, self).__init__()
        dim = 2048
        self.pos_embed = nn.Embedding(130, dim)

    def forward(self, Position):
        position_embed = self.pos_embed(Position)
        return position_embed
class Pos_L2_Embed(nn.Module):
    ## Position-2 embedding layer ##
    def __init__(self):
        super(Pos_L2_Embed, self).__init__()
        dim = 2048
        self.pos_embed = nn.Embedding(129, dim)

    def forward(self, Position):
        position_embed = self.pos_embed(Position)
        return position_embed
class Pos_L3_Embed(nn.Module):
    ## TSS-distance embedding layer ##
    def __init__(self, max_len):
        super(Pos_L3_Embed, self).__init__()
        dim = 2048
        self.pos_embed = nn.Embedding(max_len, dim)

    def forward(self, Position):
        position_embed = self.pos_embed(Position)
        return position_embed
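# The three positional embeddings differ only in vocabulary size: Pos_L1_Embed
# (130 positions) and Pos_L2_Embed (129 positions) appear to index positions at
# the two hierarchy levels of CREformer-Elementary, while Pos_L3_Embed takes
# max_len as an argument and, per its comment, encodes distance to the TSS.
# Sketch with made-up indices:
#
#   pos3 = Pos_L3_Embed(max_len=150)
#   idx = torch.arange(150).unsqueeze(0)        # (1, 150) integer positions
#   pemb = pos3(idx)                            # -> (1, 150, 2048)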
class ANN(nn.Module):
    ## Feed forward layer ##
    def __init__(self):
        super(ANN, self).__init__()
        self.l1 = nn.Linear(1 * 150 * 2048, 32)
        self.l2 = nn.Linear(32, 1)
        self.act = nn.ReLU()

    def forward(self, x1):
        x1 = x1.view(-1, 1 * 150 * 2048)
        x1 = self.act(self.l1(x1))
        x2 = self.l2(x1)
        return x2
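# Minimal shape-check sketch (an illustration, not part of the original file):
# the head above is written for 150 tokens of width 2048, flattening a
# (batch, 150, 2048) representation down to one scalar per sample. The encoder
# is instantiated with num_layer=1 purely to keep the check cheap; the real
# model uses the 20-layer class defaults, and summing the DNA and signal
# embeddings here is an assumption about how the pieces are combined.
if __name__ == "__main__":
    dna = DNA_Embedding()
    sig = Sig_Embedding()
    head = ANN()
    tokens = torch.randint(0, 4100, (1, 150))    # hypothetical DNA token ids
    signal = torch.randint(0, 38, (1, 150))      # hypothetical binned ATAC signal
    x = dna(tokens) + sig(signal)                # (1, 150, 2048); summation is an assumption
    x = Encoder(num_layer=1)(x)                  # (1, 150, 2048)
    print(head(x).shape)                         # torch.Size([1, 1])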