in safetorch/safe_network.py [0:0]
def __init__(self, config):
    """Build the SAFE network layers from a configuration object.

    Args:
        config: object exposing, as attributes, the hyper-parameters read
            here: ``num_embeddings``, ``embedding_size``, ``rnn_state_size``,
            ``rnn_depth``, ``attention_depth``, ``attention_hops`` and
            ``dense_layer_size``. It is kept as ``self.conf``.
    """
    super(SAFE, self).__init__()
    self.conf = config

    # Lookup table: num_embeddings ids -> embedding_size-dim vectors.
    self.instructions_embeddings = torch.nn.Embedding(
        self.conf.num_embeddings, self.conf.embedding_size
    )

    # Bidirectional GRU fed with embedding_size features; batch_first=True
    # means it takes (batch, seq, feature) inputs. Being bidirectional, each
    # step yields 2 * rnn_state_size output features.
    self.bidirectional_rnn = torch.nn.GRU(
        input_size=self.conf.embedding_size,
        hidden_size=self.conf.rnn_state_size,
        num_layers=self.conf.rnn_depth,
        bias=True,
        batch_first=True,
        dropout=0,
        bidirectional=True,
    )

    # Attention weight matrices. Their shapes suggest WS1 is applied to the
    # 2 * rnn_state_size GRU features and WS2 to WS1's attention_depth
    # output (attention_hops rows); the computation itself lives outside this
    # method -- NOTE(review): confirm against forward().
    #
    # torch.Tensor(...)/torch.empty(...) hand back *uninitialised* memory, so
    # without an explicit init these parameters start from arbitrary (possibly
    # non-finite, non-reproducible) values whenever no checkpoint is loaded.
    # Xavier-uniform is the scheme chosen here; load_state_dict() still
    # overwrites it for pretrained models.
    self.WS1 = Parameter(
        torch.empty(self.conf.attention_depth, 2 * self.conf.rnn_state_size)
    )
    self.WS2 = Parameter(
        torch.empty(self.conf.attention_hops, self.conf.attention_depth)
    )
    torch.nn.init.xavier_uniform_(self.WS1)
    torch.nn.init.xavier_uniform_(self.WS2)

    # Input width 2 * attention_hops * rnn_state_size is consistent with
    # flattening attention_hops vectors of 2 * rnn_state_size features each
    # (assumption -- the flattening happens outside this method).
    self.dense_1 = torch.nn.Linear(
        2 * self.conf.attention_hops * self.conf.rnn_state_size,
        self.conf.dense_layer_size,
        bias=True,
    )
    # Projects dense_layer_size features to an embedding_size-dim output.
    self.dense_2 = torch.nn.Linear(
        self.conf.dense_layer_size, self.conf.embedding_size, bias=True
    )