in tbsm_pytorch.py
def forward(self, x, lS_o, lS_i):
# Move offsets to device if needed and not already done.
if args.run_fast and not self.offsets_moved:
self.offsets = self.offsets.to(x[0].device)
self.offsets_moved = True
# each data point consists of a history H (ts_length steps) and a last entry w
n = x[0].shape[0] # batch_size
ts = len(x)
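# H holds one DLRM output vector of size ln_top[-1] for each of the ts_length history steps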
H = torch.zeros(n, self.ts_length, self.ln_top[-1], device=x[0].device)
# Compute H using either fast or original approach depending on args.run_fast.
if args.run_fast:
# j indexes the input time steps; first determine the j bounds, then gather all the inputs at once.
j_lower = (ts - self.ts_length - 1)
j_upper = (ts - 1)
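# These bounds select the ts_length history entries x[j_lower:j_upper]; the final entry x[-1] is handled separately as w below.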
# Concatenate x[j]s using j bounds.
concatenated_x = torch.cat(x[j_lower : j_upper])
# Set offsets and increase size if needed.
curr_max_offset = (j_upper - j_lower) * n
if curr_max_offset > self.max_offset + 1:
# Resize offsets to 2x required size.
self.offsets = torch.arange(curr_max_offset * 2, device=self.offsets.device)
self.max_offset = curr_max_offset * 2
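# Doubling (rather than resizing exactly) presumably amortizes the cost of regrowing the offset tensor over future calls.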
concatenated_lS_o = [self.offsets[: curr_max_offset] for j in range(len(lS_o[0]))]
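# One offset per row: each sample is assumed to contribute a single sparse index per feature, so the same 0..curr_max_offset-1 offsets are reused for every sparse feature.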
# For each sparse feature, concatenate its index tensors lS_i[i] across the j range.
concatenated_lS_i = [torch.cat([lS_i[i][j] for i in range(j_lower, j_upper)]) for j in range(len(lS_i[0]))]
# oj determines access indices of the output H; oj is just the j indices shifted to start at 0.
oj_lower = j_lower - (ts - self.ts_length - 1)
oj_upper = j_upper - (ts - self.ts_length - 1)
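# With these bounds, H[:, oj_lower:oj_upper, :] spans all ts_length history slots.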
# After fetching all inputs, run through DLRM.
concatenated_dlrm_output = self.dlrm(concatenated_x, concatenated_lS_o, concatenated_lS_i)
# Reshape the flat DLRM output to (time, batch, dim) and transpose to (batch, time, dim) to match H.
transposed_concatenated_dlrm_output = torch.transpose(concatenated_dlrm_output.reshape((j_upper - j_lower), n, self.ln_top[-1]), 0, 1)
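# As in the per-step path below, optionally project each embedding onto the unit sphere (L2-normalize along the feature dimension).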
if self.model_type == "tsl" and self.tsl_proj:
dlrm_output = Functional.normalize(transposed_concatenated_dlrm_output, p=2, dim=2)
else:
dlrm_output = transposed_concatenated_dlrm_output
# Assign the output to H with correct output bounds.
H[:, oj_lower : oj_upper, :] = dlrm_output
else:
# original per-step path: split the data point into its
# history entries (processed one at a time) and the last item w
for j in range(ts - self.ts_length - 1, ts - 1):
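# oj is the 0-based position of step j within the ts_length-long history window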
oj = j - (ts - self.ts_length - 1)
v = self.dlrm(x[j], lS_o[j], lS_i[j])
if self.model_type == "tsl" and self.tsl_proj:
v = Functional.normalize(v, p=2, dim=1)
H[:, oj, :] = v
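# embed the last entry w with the same DLRM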
w = self.dlrm(x[-1], lS_o[-1], lS_i[-1])
# project onto sphere
if self.model_type == "tsl" and self.tsl_proj:
w = Functional.normalize(w, p=2, dim=1)
# print("data: ", x[-1], lS_o[-1], lS_i[-1])
(mini_batch_size, _) = w.shape
# for cases when model is tsl or mha
if self.model_type != "rnn":
# apply each TSL component in turn
# each ams[] element is one attention component with its own MLP
for j in range(self.num_mlps):
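# component j attends over only the most recent ts_array[j] steps of the history H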
start = self.ts_length - self.ts_array[j]
c = self.ams[j](w, H[:, start:, :])
c = torch.reshape(c, (mini_batch_size, -1))
# concat context and w
z = torch.cat([c, w], dim=1)
# obtain probability of a click as a result of MLP
p = dlrm.DLRM_Net().apply_mlp(z, self.mlps[j])
if j == 0:
ps = p
else:
ps = torch.cat((ps, p), dim=1)
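# when there is more than one component, combine the per-component outputs with a final MLP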
if ps.shape[1] > 1:
p_out = dlrm.DLRM_Net().apply_mlp(ps, self.final_mlp)
else:
p_out = ps
# RNN case based on stacked LSTM cells; the context is the final hidden state
else:
hidden_dim = w.shape[1] # equal to dim(w) = dim(c)
level = self.rnn_num_layers # num stacks of rnns
Ht = H.permute(1, 0, 2)
rnn = nn.LSTM(int(self.ln_top[-1]), int(hidden_dim),
int(level)).to(x[0].device)
h0 = torch.randn(level, n, hidden_dim).to(x[0].device)
c0 = torch.randn(level, n, hidden_dim).to(x[0].device)
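# note: the LSTM and its random initial states are created anew on each forward call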
output, (hn, cn) = rnn(Ht, (h0, c0))
hn, cn = torch.squeeze(hn[level - 1, :, :]), \
torch.squeeze(cn[level - 1, :, :])
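# keep only the top layer's final hidden and cell states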
if self.debug_mode:
print(w.shape, output.shape, hn.shape)
# concat context and w
z = torch.cat([hn, w], dim=1)
p_out = dlrm.DLRM_Net().apply_mlp(z, self.mlps[0])
return p_out