in src/deep_baselines/cheer.py [0:0]
def forward(self, x, lengths=None, labels=None):
    """Run the multi-channel CNN on a batch of token ids and optionally compute a loss.

    The embedded sequence is split into ``self.channel_in`` equal, contiguous
    chunks along the sequence axis. Each of the first six chunks is passed
    through its own list of convolutions (``self.convs1`` ... ``self.convs6``),
    then ReLU, then global max-pooling over the chunk length. The pooled
    features of all branches are concatenated and fed through dropout,
    ``fc1`` and ``fc2`` (each followed by ReLU), and ``linear_layer``.

    Args:
        x: integer token tensor of shape ``(B, L)``. ``L`` must be divisible
            by ``self.channel_in`` (required by the reshape below).
        lengths: unused; kept for interface compatibility.
        labels: optional targets. When given, a loss is computed according
            to ``self.output_mode``. For ``multi_class``, either class
            indices (one per sample, e.g. shape ``(B,)``) or per-class
            targets of shape ``(B, num_labels)`` are accepted.

    Returns:
        ``[logits, output]`` when ``labels`` is None, otherwise
        ``[loss, logits, output]``. ``output`` is ``self.output(logits)``
        when ``self.output`` is truthy, and ``logits`` otherwise.

    Raises:
        ValueError: if ``labels`` is given and ``self.output_mode`` is not a
            supported mode.
    """
    # Unpacking also enforces that x is 2-D.
    B, _ = x.shape
    # (B, L) -> (B, L, embedding_dim)
    x = self.embedding(x)
    # Split the sequence axis into channel_in contiguous chunks:
    # (B, L, D) -> (B, channel_in, L // channel_in, D)
    x = x.reshape((B, self.channel_in, -1, self.embedding_dim))

    # NOTE(review): only six conv branches exist. If channel_in > 6 the extra
    # chunks are ignored; if channel_in < 6 the indexing below raises IndexError.
    branch_convs = (
        self.convs1,
        self.convs2,
        self.convs3,
        self.convs4,
        self.convs5,
        self.convs6,
    )
    pooled = []
    for channel, convs in enumerate(branch_convs):
        # (B, 1, L_c, D): one single-input-channel 2-D map per chunk.
        chunk = x[:, channel, :, :].reshape(B, 1, -1, self.embedding_dim)
        for conv in convs:
            # squeeze(3) assumes each conv collapses the embedding axis to
            # width 1 (e.g. kernel width == embedding_dim) -- TODO confirm.
            feat = F.relu(conv(chunk)).squeeze(3)
            # Global max-pool over the remaining length axis -> (B, out_channels).
            pooled.append(F.max_pool1d(feat, feat.size(2)).squeeze(2))
    # Branch-major, then conv-list order: same layout as concatenating per branch.
    x = torch.cat(pooled, 1)

    x = self.dropout(x)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    logits = self.linear_layer(x)
    output = self.output(logits) if self.output else logits
    outputs = [logits, output]

    if labels is not None:
        mode = self.output_mode
        if mode in ["regression"]:
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
        elif mode in ["multi_label", "multi-label"]:
            loss = self.loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1, self.num_labels).float(),
            )
        elif mode in ["binary_class", "binary-class"]:
            loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
        elif mode in ["multi_class", "multi-class"]:
            flat_logits = logits.view(-1, self.num_labels)
            if self.num_labels > 1 and labels.numel() == flat_logits.size(0):
                # One class index per sample. Reshaping these to
                # (-1, num_labels) would not line up with the logit rows, so
                # pass them flat as integers (assumes self.loss_fct accepts
                # index targets, e.g. CrossEntropyLoss -- confirm).
                loss = self.loss_fct(flat_logits, labels.reshape(-1).long())
            else:
                # Per-class targets (e.g. one-hot / probabilities), (B, num_labels).
                loss = self.loss_fct(
                    flat_logits, labels.view(-1, self.num_labels).float()
                )
        else:
            raise ValueError(
                f"Unsupported output_mode for loss computation: {mode!r}"
            )
        outputs = [loss, *outputs]
    return outputs