def forward()

in src/deep_baselines/cheer.py [0:0]
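
CHEER's forward pass embeds a (B, L) matrix of token ids, splits each embedded sequence into six contiguous chunks, pushes every chunk through its own bank of conv + ReLU + global max-pool branches (one TextCNN per chunk), concatenates all pooled features, and runs them through a small MLP head to produce logits, an optional post-activation output, and, when labels are given, a loss.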


    def forward(self, x, lengths=None, labels=None):
        B, L = x.shape  # x: (B, L) integer token indices
        x = self.embedding(x)  # (B, L) token ids -> (B, L, embedding_dim)
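        # Split the length axis into channel_in (= 6) contiguous chunks:
        # (B, L, dim) -> (B, 6, L/6, dim); each chunk then gets a singleton
        # conv-channel axis, giving (B, 1, L/6, dim)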
        x = x.reshape((B, self.channel_in, -1, self.embedding_dim))
        x1 = x[:, 0, :, :].reshape(B, 1, -1, self.embedding_dim)
        x2 = x[:, 1, :, :].reshape(B, 1, -1, self.embedding_dim)
        x3 = x[:, 2, :, :].reshape(B, 1, -1, self.embedding_dim)
        x4 = x[:, 3, :, :].reshape(B, 1, -1, self.embedding_dim)
        x5 = x[:, 4, :, :].reshape(B, 1, -1, self.embedding_dim)
        x6 = x[:, 5, :, :].reshape(B, 1, -1, self.embedding_dim)
        # Per-chunk TextCNN branch: each conv kernel spans the full embedding
        # width, so conv(xi) is (B, out_channels, L', 1) and squeeze(3) drops
        # the unit width axis, leaving (B, out_channels, L')
        x1 = [F.relu(conv(x1)).squeeze(3) for conv in self.convs1]
        x2 = [F.relu(conv(x2)).squeeze(3) for conv in self.convs2]
        x3 = [F.relu(conv(x3)).squeeze(3) for conv in self.convs3]
        x4 = [F.relu(conv(x4)).squeeze(3) for conv in self.convs4]
        x5 = [F.relu(conv(x5)).squeeze(3) for conv in self.convs5]
        x6 = [F.relu(conv(x6)).squeeze(3) for conv in self.convs6]

        # Global max-pool over positions: (B, out_channels, L') -> (B, out_channels)
        x1 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x1]
        x2 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x2]
        x3 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x3]
        x4 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x4]
        x5 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x5]
        x6 = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x6]

        # Within each branch, concatenate the pooled outputs of all its convs
        x1 = torch.cat(x1, 1)
        x2 = torch.cat(x2, 1)
        x3 = torch.cat(x3, 1)
        x4 = torch.cat(x4, 1)
        x5 = torch.cat(x5, 1)
        x6 = torch.cat(x6, 1)

        # Fuse the six branch feature vectors into one vector per sequence
        x = torch.cat((x1, x2, x3, x4, x5, x6), 1)

        # Classifier head: dropout, two ReLU layers, then label logits
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.linear_layer(x)

        # Optional output transform (e.g. sigmoid/softmax configured at init)
        if self.output is not None:
            output = self.output(logits)
        else:
            output = logits

        outputs = [logits, output]  # loss is prepended below when labels are given
        if labels is not None:
            if self.output_mode in ["regression"]:
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.output_mode in ["multi_label", "multi-label"]:
                # multi-hot float targets, one column per label
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            elif self.output_mode in ["binary_class", "binary-class"]:
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.output_mode in ["multi_class", "multi-class"]:
                # class-index targets of shape (B,) for a cross-entropy-style loss
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = [loss, *outputs]

        return outputs
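
To make the per-branch plumbing concrete, here is a self-contained toy sketch of one branch (chunk split, conv, ReLU, global max-pool). The sizes and kernel height below are made up for illustration and are not CHEER's real configuration:

    import torch
    import torch.nn.functional as F

    B, L, dim, channel_in = 2, 60, 16, 6        # toy sizes, not from cheer.py
    x = torch.randn(B, L, dim)                  # stands in for self.embedding(x)

    x = x.reshape(B, channel_in, -1, dim)       # (B, 6, L/6, dim)
    x1 = x[:, 0, :, :].reshape(B, 1, -1, dim)   # first chunk: (B, 1, L/6, dim)

    conv = torch.nn.Conv2d(1, 8, kernel_size=(3, dim))  # kernel spans the embedding width
    h = F.relu(conv(x1)).squeeze(3)             # (B, 8, L/6 - 2)
    h = F.max_pool1d(h, h.size(2)).squeeze(2)   # global max-pool -> (B, 8)
    print(h.shape)                              # torch.Size([2, 8])

Calling convention implied by the method (model construction elided; the constructor is defined elsewhere in cheer.py): x is a (B, L) integer tensor with L divisible by channel_in (6 here); with labels the call returns [loss, logits, output], without labels it returns [logits, output].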