def forward()

in tbsm_pytorch.py [0:0]


    def forward(self, x=None, H=None):
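        # inputs (shapes implied by the ops below):
        #   x: current-item vector, shape (batch_size, vector_dim)
        #   H: history matrix, shape (batch_size, ts_length, vector_dim)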
        # adjust input shape
        (batchSize, vector_dim) = x.shape
        x = torch.reshape(x, (batchSize, 1, -1))
        x = torch.transpose(x, 1, 2)
        # debug prints
        # print("shapes: ", self.A.shape, x.shape)

        # perform mode operation
        if self.model_type == "tsl":
            if self.tsl_inner == "def":
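                # "def": x <- (A^T A) x, so the similarity against H below uses
                # the positive semidefinite Gram matrix A^T A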
                ax = torch.matmul(self.A, x)
                x = torch.matmul(self.A.permute(0, 2, 1), ax)
                # debug prints
                # print("shapes: ", H.shape, ax.shape, x.shape)
            elif self.tsl_inner == "ind":
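                # "ind": x <- A x, a single projection without the Gram-matrix symmetrization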
                x = torch.matmul(self.A, x)

            # perform interaction operation
            if self.arch_interaction_op == 'dot':
                if self.arch_attention_mechanism == 'mul':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
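                    # a = (H x)^T has shape (batch_size, 1, ts_length)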
                    # context
                    c = torch.bmm(a, H)
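                    # c = a H has shape (batch_size, 1, vector_dim): a weighted sum of the rows of H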
                elif self.arch_attention_mechanism == 'mlp':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
                    # MLP first/last layer dims are automatically adjusted to ts_length
                    y = dlrm.DLRM_Net().apply_mlp(a, self.mlp)
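                    # y: MLP applied to the ts_length coefficients; reshaped back to
                    # (batch_size, 1, ts_length) before forming the context below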
                    # context: c = y H, where y = mlp(a)
                    c = torch.bmm(torch.reshape(y, (batchSize, 1, -1)), H)
                else:
                    sys.exit('ERROR: --arch-attention-mechanism='
                        + self.arch_attention_mechanism + ' is not supported')

            else:
                sys.exit('ERROR: --arch-interaction-op=' + self.arch_interaction_op
                    + ' is not supported')

        elif self.model_type == "mha":
            x = torch.transpose(x, 1, 2)
            Qx = torch.transpose(torch.matmul(x, self.Q), 0, 1)
            HK = torch.transpose(torch.matmul(H, self.K), 0, 1)
            HV = torch.transpose(torch.matmul(H, self.V), 0, 1)
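            # Qx/HK/HV: query, key, and value projections, transposed to the
            # (seq_len, batch, embed_dim) layout expected by nn.MultiheadAttention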
            # multi-head attention (mha)
            multihead_attn = nn.MultiheadAttention(self.emb_m, self.nheads).to(x.device)
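            # note: the attention module is constructed here on every call, so its
            # internal projection weights are re-initialized each forward pass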
            attn_output, _ = multihead_attn(Qx, HK, HV)
            # context
            c = torch.squeeze(attn_output, dim=0)
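            # c has shape (batch_size, emb_m) after dropping the tgt_len = 1 dimension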
            # debug prints
            # print("shapes:", c.shape, Qx.shape)

        return c
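
For a quick sanity check of the tensor shapes in the "tsl" path with the "def"
inner product and "mul" attention, the snippet below reproduces the same
operations on dummy tensors. It is a minimal standalone sketch, not the module
itself: the sizes (batch_size, m, k, ts_length) and the (1, k, m) shape assumed
for A are illustrative only.

    import torch

    batch_size, m, k, ts_length = 32, 16, 8, 20

    x = torch.randn(batch_size, m)             # current-item vector
    H = torch.randn(batch_size, ts_length, m)  # history matrix
    A = torch.randn(1, k, m)                   # stand-in for the learned parameter A

    x = torch.reshape(x, (batch_size, 1, -1)).transpose(1, 2)   # (batch_size, m, 1)
    x = torch.matmul(A.permute(0, 2, 1), torch.matmul(A, x))    # "def": (A^T A) x
    a = torch.transpose(torch.bmm(H, x), 1, 2)                  # (batch_size, 1, ts_length)
    c = torch.bmm(a, H)                                         # (batch_size, 1, m)

    print(a.shape, c.shape)  # torch.Size([32, 1, 20]) torch.Size([32, 1, 16])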