in safetorch/safe_network.py [0:0]
def forward(self, instructions, lengths):
# for now assume a batch size of 1
batch_size = 1
# check valid input
if lengths[0] <= 0:
return torch.zeros(batch_size, self.conf.embedding_size)
# each functions is a list of embeddings id
# (an id is an index in the embedding matrix)
# with this we transform it in a list of embeddings vectors.
instructions_vectors = self.instructions_embeddings(instructions)
# consider only valid instructions (defdined by lengths)
valid_instructions = torch.split(instructions_vectors, lengths[0], 0)[0]
# We create the GRU RNN
output, h_n = self.bidirectional_rnn(valid_instructions.unsqueeze(0))
pad = torch.zeros(
1, self.conf.max_instructions - lengths[0], self.conf.embedding_size
)
# We create the matrix H
H = torch.cat((output, pad), 1)
# We do a tile to account for training batches
ws1_tiled = self.WS1.unsqueeze(0)
ws2_tiled = self.WS2.unsqueeze(0)
# we compute the matrix A
A = torch.softmax(
ws2_tiled.matmul(torch.tanh(ws1_tiled.matmul(H.transpose(1, 2)))), 2
)
# embedding matrix M
M = A.matmul(H)
# we create the flattened version of M
flattened_M = M.view(
batch_size, 2 * self.conf.attention_hops * self.conf.rnn_state_size
)
dense_1_out = F.relu(self.dense_1(flattened_M))
function_embedding = F.normalize(self.dense_2(dense_1_out), dim=1, p=2)
return function_embedding