def forward()

in tbsm_pytorch.py [0:0]


    def forward(self, x, lS_o, lS_i):
        # Move offsets to device if needed and not already done.
        if args.run_fast and not self.offsets_moved:
            self.offsets = self.offsets.to(x[0].device)
            self.offsets_moved = True

        # data point is history H and last entry w
        n = x[0].shape[0]  # batch_size
        ts = len(x)
        H = torch.zeros(n, self.ts_length, self.ln_top[-1]).to(x[0].device)

        # Compute H using either fast or original approach depending on args.run_fast.
        if args.run_fast:
            # j indexes the input time steps; first determine the j bounds, then gather all inputs at once.
            j_lower = (ts - self.ts_length - 1)
            j_upper = (ts - 1)

            # Concatenate the dense inputs x[j] for j in [j_lower, j_upper) along the batch dimension.
            concatenated_x = torch.cat(x[j_lower : j_upper])

            # Grow the cached offsets tensor if the required length exceeds its current capacity.
            curr_max_offset = (j_upper - j_lower) * n
            if curr_max_offset > self.max_offset + 1:
                # Resize offsets to 2x the required size to avoid frequent reallocation.
                self.offsets = torch.arange(curr_max_offset * 2, device=self.offsets.device)
                self.max_offset = curr_max_offset * 2

            concatenated_lS_o = [self.offsets[: curr_max_offset] for _ in range(len(lS_o[0]))]
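            # (self.offsets[:curr_max_offset] is simply [0, 1, ..., curr_max_offset - 1],
            # reused as the offsets for every sparse feature.)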

            # For each sparse feature, concatenate its index tensors across the selected time steps.
            concatenated_lS_i = [torch.cat([lS_i[t][f] for t in range(j_lower, j_upper)]) for f in range(len(lS_i[0]))]

            # oj indexes the output (time) dimension of H; it is just j shifted so that j_lower maps to 0.
            oj_lower = j_lower - (ts - self.ts_length - 1)
            oj_upper = j_upper - (ts - self.ts_length - 1)

            # After fetching all inputs, run through DLRM.
            concatenated_dlrm_output = self.dlrm(concatenated_x, concatenated_lS_o, concatenated_lS_i)

            # Reshape the output to restore the time-step dimension, then transpose to H's (batch, time, feature) layout.
            transposed_concatenated_dlrm_output = torch.transpose(concatenated_dlrm_output.reshape((j_upper - j_lower), n, self.ln_top[-1]), 0, 1)
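            # torch.cat stacked time steps along the batch dimension, so row (t * n + b) of the
            # DLRM output belongs to time step t and sample b; the reshape/transpose above
            # recovers the (batch, time, feature) layout expected by H.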
            if self.model_type == "tsl" and self.tsl_proj:
                dlrm_output = Functional.normalize(transposed_concatenated_dlrm_output, p=2, dim=2)
            else:
                dlrm_output = transposed_concatenated_dlrm_output

            # Assign the output to H with correct output bounds.
            H[:, oj_lower : oj_upper, :] = dlrm_output

        else:
            # Original (non-fast) path: split the data point into its history entries
            # and the last item, and run DLRM on each history time step separately.
            for j in range(ts - self.ts_length - 1, ts - 1):
                oj = j - (ts - self.ts_length - 1)
                v = self.dlrm(x[j], lS_o[j], lS_i[j])
                if self.model_type == "tsl" and self.tsl_proj:
                    v = Functional.normalize(v, p=2, dim=1)
                H[:, oj, :] = v

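        # w is the DLRM embedding of the last (most recent) entry of the data point.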
        w = self.dlrm(x[-1], lS_o[-1], lS_i[-1])
        # project onto sphere
        if self.model_type == "tsl" and self.tsl_proj:
            w = Functional.normalize(w, p=2, dim=1)
        # print("data: ", x[-1], lS_o[-1], lS_i[-1])

        (mini_batch_size, _) = w.shape

        # for cases when model is tsl or mha
        if self.model_type != "rnn":

            # apply one TSL/attention component per iteration;
            # each self.ams[] element is one component, paired with its own MLP in self.mlps[]
            for j in range(self.num_mlps):

                ts = self.ts_length - self.ts_array[j]
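                # H[:, ts:, :] keeps only the most recent self.ts_array[j] history vectors for this component.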
                c = self.ams[j](w, H[:, ts:, :])
                c = torch.reshape(c, (mini_batch_size, -1))
                # concat context and w
                z = torch.cat([c, w], dim=1)
                # obtain probability of a click as a result of MLP
                p = dlrm.DLRM_Net().apply_mlp(z, self.mlps[j])
                if j == 0:
                    ps = p
                else:
                    ps = torch.cat((ps, p), dim=1)

            if ps.shape[1] > 1:
                p_out = dlrm.DLRM_Net().apply_mlp(ps, self.final_mlp)
            else:
                p_out = ps

        # RNN based on LSTM cells case, context is final hidden state
        else:
            hidden_dim = w.shape[1]     # equal to dim(w) = dim(c)
            level = self.rnn_num_layers  # num stacks of rnns
            Ht = H.permute(1, 0, 2)
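            # Note: a new LSTM (and the random h0/c0 below) is created on every forward call,
            # so its weights are not registered as trained parameters of this module.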
            rnn = nn.LSTM(
                int(self.ln_top[-1]), int(hidden_dim), int(level)
            ).to(x[0].device)
            h0 = torch.randn(level, n, hidden_dim).to(x[0].device)
            c0 = torch.randn(level, n, hidden_dim).to(x[0].device)
            output, (hn, cn) = rnn(Ht, (h0, c0))
            hn, cn = torch.squeeze(hn[level - 1, :, :]), \
                torch.squeeze(cn[level - 1, :, :])
            if self.debug_mode:
                print(w.shape, output.shape, hn.shape)
            # concat context and w
            z = torch.cat([hn, w], dim=1)
            p_out = dlrm.DLRM_Net().apply_mlp(z, self.mlps[0])

        return p_out
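
A minimal usage sketch of how this forward pass might be driven. The class name TBSM_Net, its constructor arguments, and the toy sizes below are assumptions for illustration only; the input layout (x as a list of per-time-step dense tensors, lS_o / lS_i as per-time-step lists of per-feature offset and index tensors) is inferred from the indexing in the code above.

    import torch

    # Assumed toy sizes (hypothetical; not taken from tbsm_pytorch.py).
    batch_size, ts, num_dense, num_sparse, table_size = 4, 21, 13, 3, 100

    # One dense tensor per time step, each of shape (batch_size, num_dense).
    x = [torch.randn(batch_size, num_dense) for _ in range(ts)]

    # Per time step: one offset tensor and one index tensor per sparse feature
    # (offsets 0..batch_size-1 give one lookup index per sample).
    lS_o = [[torch.arange(batch_size) for _ in range(num_sparse)] for _ in range(ts)]
    lS_i = [[torch.randint(0, table_size, (batch_size,)) for _ in range(num_sparse)]
            for _ in range(ts)]

    # model = TBSM_Net(...)       # hypothetical constructor; see tbsm_pytorch.py for the real one
    # p = model(x, lS_o, lS_i)    # click probability, shape (batch_size, 1)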