in tbsm_pytorch.py [0:0]
def forward(self, x=None, H=None):
    # adjust input shape
    (batchSize, vector_dim) = x.shape
    x = torch.reshape(x, (batchSize, 1, -1))
    x = torch.transpose(x, 1, 2)
    # debug prints
    # print("shapes: ", self.A.shape, x.shape)
    # perform mode operation
    if self.model_type == "tsl":
        if self.tsl_inner == "def":
            ax = torch.matmul(self.A, x)
            x = torch.matmul(self.A.permute(0, 2, 1), ax)
            # debug prints
            # print("shapes: ", H.shape, ax.shape, x.shape)
        elif self.tsl_inner == "ind":
            x = torch.matmul(self.A, x)
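        # tsl_inner == "def" applies the learned matrix twice, x <- A^T (A x),
        # while "ind" uses the single projection x <- A x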
        # perform interaction operation
        if self.arch_interaction_op == 'dot':
            if self.arch_attention_mechanism == 'mul':
                # coefficients
                a = torch.transpose(torch.bmm(H, x), 1, 2)
                # context
                c = torch.bmm(a, H)
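                # a: (batchSize, 1, ts_length) attention coefficients over the sequence
                # c: (batchSize, 1, vector_dim) context vector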
            elif self.arch_attention_mechanism == 'mlp':
                # coefficients
                a = torch.transpose(torch.bmm(H, x), 1, 2)
                # MLP first/last layer dims are automatically adjusted to ts_length
                y = dlrm.DLRM_Net().apply_mlp(a, self.mlp)
                # context, y = mlp(a)
                c = torch.bmm(torch.reshape(y, (batchSize, 1, -1)), H)
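                # the MLP maps the ts_length coefficients in a to ts_length outputs;
                # reshaping y to (batchSize, 1, ts_length) before the bmm again yields
                # a (batchSize, 1, vector_dim) context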
            else:
                sys.exit('ERROR: --arch-attention-mechanism='
                         + self.arch_attention_mechanism + ' is not supported')
        else:
            sys.exit('ERROR: --arch-interaction-op=' + self.arch_interaction_op
                     + ' is not supported')
    elif self.model_type == "mha":
        x = torch.transpose(x, 1, 2)
        Qx = torch.transpose(torch.matmul(x, self.Q), 0, 1)
        HK = torch.transpose(torch.matmul(H, self.K), 0, 1)
        HV = torch.transpose(torch.matmul(H, self.V), 0, 1)
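        # Qx is a length-1 query sequence of shape (1, batchSize, self.emb_m);
        # HK and HV have shape (ts_length, batchSize, self.emb_m), i.e. the
        # (seq_len, batch, embed_dim) layout nn.MultiheadAttention expects by default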
        # multi-head attention (mha)
        multihead_attn = nn.MultiheadAttention(self.emb_m, self.nheads).to(x.device)
        attn_output, _ = multihead_attn(Qx, HK, HV)
        # context
        c = torch.squeeze(attn_output, dim=0)
        # debug prints
        # print("shapes:", c.shape, Qx.shape)
    return c
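A minimal, self-contained sketch of the "mul" attention path above, using made-up sizes and a random stand-in for the learned parameter self.A; it only illustrates the shape flow, not the trained TBSM model:

import torch

# hypothetical sizes, chosen only for illustration
batchSize, ts_length, vector_dim = 2, 5, 4
x = torch.randn(batchSize, vector_dim)             # input vector for each example
H = torch.randn(batchSize, ts_length, vector_dim)  # per-step embeddings ("history")
A = torch.randn(1, vector_dim, vector_dim)         # stand-in for the learned self.A

x = torch.transpose(torch.reshape(x, (batchSize, 1, -1)), 1, 2)  # (batchSize, vector_dim, 1)
x = torch.matmul(A.permute(0, 2, 1), torch.matmul(A, x))         # tsl_inner == "def": A^T (A x)
a = torch.transpose(torch.bmm(H, x), 1, 2)                       # (batchSize, 1, ts_length)
c = torch.bmm(a, H)                                              # (batchSize, 1, vector_dim)
print(a.shape, c.shape)  # torch.Size([2, 1, 5]) torch.Size([2, 1, 4])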