def synthetic_experiment()

in tbsm_synthetic.py


import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from sklearn.metrics import roc_auc_score

# assumed import: the DLRM reference implementation (dlrm_s_pytorch.py),
# which provides the create_mlp/apply_mlp helpers used below
import dlrm_s_pytorch as dlrm


def synthetic_experiment():

    N, Nt, D, T = 50000, 5000, 5, 10

    auc_results = np.empty((0, 5), np.float32)

    def generate_data(N, high):
        # draw histories H (N, T, D) and context vectors w (N, 1, D)
        # uniformly from [-high, high]
        H = np.random.uniform(low=-high, high=high, size=N * D * T).reshape(N, T, D)
        w = np.random.uniform(low=-high, high=high, size=N * D).reshape(N, 1, D)
        return H, w
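
    # label construction (see the loop below): the context vector is
    # repeatedly augmented with permuted copies of itself, w <- w + P_j w,
    # i.e. w' = prod_j (I + P_j) w, and each sample is labeled
    # y = (sign(mean_t <h_t, w'>) + 1) / 2 in {0, 1}; larger K makes the
    # boundary harder to fit with a plain dot product against the raw w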

    for K in range(0, 31, 10):

        print("num q terms: ", K)
        # ----- train set ------
        H, w = generate_data(N, 1.0)
        wt = np.transpose(w, (0, 2, 1))
        p = np.zeros(D * K, dtype=np.int64).reshape(K, D)
        for j in range(K):
            p[j, :] = np.random.permutation(D)
            wt2 = wt[:, p[j], :]
            wt = wt + wt2
        Q = np.matmul(H, wt)  # per-step similarity coefficients
        Q = np.squeeze(Q, axis=2)
        R = np.mean(Q, axis=1)
        R = (np.sign(R) + 1) / 2  # map signs {-1, +1} to binary labels {0, 1}
        # s1 = np.count_nonzero(R > 0)
        # print("num pos, total: ", s1, N)
        t_train = R.reshape(N, 1)
        z_train = np.concatenate((H, w), axis=1)
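        # w is appended as an extra (T + 1)-th row of each sample; forward()
        # splits it back out via x[:, :-1, :] and x[:, -1, :]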

        # ----- test set ------
        H, w = generate_data(Nt, 1.0)
        wt = np.transpose(w, (0, 2, 1))
        for j in range(K):
            wt2 = wt[:, p[j], :]
            wt = wt + wt2
        Q = np.matmul(H, wt)  # dot product
        Q = np.squeeze(Q, axis=2)
        R = np.mean(Q, axis=1)
        R = (np.sign(R) + 1) / 2  # same binary {0, 1} labels as the train set
        t_test = R.reshape(Nt, 1)
        z_test = np.concatenate((H, w), axis=1)
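        # the permutations p drawn while building the train set are reused
        # here, so train and test labels come from the same function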
        # debug prints
        # print(z_train.shape, t_train.shape)

        class SyntheticDataset:
            def __init__(self, F, y):
                self.F = F
                self.y = y

            def __getitem__(self, index):

                if isinstance(index, slice):
                    return [
                        self[idx] for idx in range(
                            index.start or 0, index.stop or len(self), index.step or 1
                        )
                    ]

                return self.F[index], self.y[index]

            def __len__(self):
                return len(self.y)

        ztraind = SyntheticDataset(z_train, t_train)
        ztestd = SyntheticDataset(z_test, t_test)
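        # each dataset item is an (F, y) pair of numpy arrays; collate_zfn
        # below stacks a list of such pairs into batch tensors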

        def collate_zfn(list_of_tuples):
            # stack the (F, y) pairs into contiguous arrays first to avoid
            # the slow tensor-from-list-of-ndarrays path
            data = list(zip(*list_of_tuples))
            F = torch.tensor(np.array(data[0]), dtype=torch.float)
            y = torch.tensor(np.array(data[1]), dtype=torch.float)
            # y = torch.unsqueeze(y, 1)
            return F, y

        ztrain_ld = torch.utils.data.DataLoader(
            ztraind,
            batch_size=128,
            num_workers=0,
            collate_fn=collate_zfn,
            shuffle=True
        )
        ztest_ld = torch.utils.data.DataLoader(
            ztestd,
            batch_size=Nt,
            num_workers=0,
            collate_fn=collate_zfn,
        )
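        # batch_size=Nt gives a single test batch, so the eval loop scores
        # the whole test set in one forward pass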

        ### define TBSM in PyTorch ###
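        # variants exercised below (summarized from the forward() branches):
        #   "dot": z = H w                    plain inner products, T scores
        #   "ind": z_j = H (A_j w)            one learned map per term
        #   "def": z_j = H (A_j^T A_j w)      PSD quadratic form per term
        #   "mha": multi-head attention with w as query, H as keys/values
        #   other: flatten and concatenate H and w for the MLP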
        class TBSM_SubNet(nn.Module):
            def __init__(
                    self,
                    mode,
                    num_inner,
                    D,
                    T,
            ):
                super(TBSM_SubNet, self).__init__()

                self.mode = mode
                self.num_inner = num_inner
                if self.mode in ["def", "ind", "dot"]:
                    if self.mode in ["def", "ind"]:
                        # learned D x D maps, initialized near the identity with
                        # Xavier-style noise, std = sqrt(2 / (fan_in + fan_out));
                        # nn.ParameterList (not a plain Python list) is required
                        # so the A_j are registered and seen by the optimizer
                        self.A = nn.ParameterList()
                        mean = 0.0
                        std_dev = np.sqrt(2 / (D + D))
                        for _ in range(self.num_inner):
                            E = np.eye(D, dtype=np.float32)
                            W = np.random.normal(mean, std_dev, size=(1, D, D)) \
                                .astype(np.float32)
                            self.A.append(Parameter(torch.tensor(E + W),
                                                    requires_grad=True))

                    d = self.num_inner * T
                    # d = self.num_inner * D + D
                    ln_mlp = np.array([d, 2 * d, 1])
                    self.mlp = dlrm.DLRM_Net().create_mlp(ln_mlp, ln_mlp.size - 2)
                elif self.mode == "mha":
                    m = D           # dim
                    self.nheads = 8
                    self.emb_m = self.nheads * m  # mha emb dim
                    mean = 0.0
                    std_dev = np.sqrt(2 / (m + m))  # np.sqrt(1 / m) # np.sqrt(1 / n)
                    qm = np.random.normal(mean, std_dev, size=(1, m, self.emb_m)) \
                        .astype(np.float32)
                    self.Q = Parameter(torch.tensor(qm), requires_grad=True)
                    km = np.random.normal(mean, std_dev, size=(1, m, self.emb_m))  \
                        .astype(np.float32)
                    self.K = Parameter(torch.tensor(km), requires_grad=True)
                    vm = np.random.normal(mean, std_dev, size=(1, m, self.emb_m)) \
                        .astype(np.float32)
                    self.V = Parameter(torch.tensor(vm), requires_grad=True)
                    d = self.nheads * m
                    ln_mlp = np.array([d, 2 * d, 1])
                    self.mlp = dlrm.DLRM_Net().create_mlp(ln_mlp, ln_mlp.size - 2)
                else:
                    d = D * (T + 1)
                    ln_mlp = np.array([d, 2 * d, 1])
                    self.mlp = dlrm.DLRM_Net().create_mlp(ln_mlp, ln_mlp.size - 2)
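
                # resulting MLP input widths: num_inner * T for "def"/"ind",
                # T for "dot" (num_inner = 1), nheads * D for "mha", and
                # (T + 1) * D for the concat fallback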

            def forward(self, x):

                # H * w
                H = x[:, :-1, :]
                w = torch.unsqueeze(x[:, -1, :], dim=1)
                w = torch.transpose(w, 1, 2)
                # inner products
                if self.mode in ["def", "ind"]:
                    for j in range(self.num_inner):
                        aw = torch.matmul(self.A[j], w)
                        if self.mode == "def":
                            aw = torch.matmul(self.A[j].permute(0, 2, 1), aw)
                        a1 = torch.bmm(H, aw)
                        if j == 0:
                            z = a1
                        else:
                            z = torch.cat([z, a1], dim=1)
                    z = torch.squeeze(z, dim=2)
                elif self.mode == "dot":
                    z = torch.bmm(H, w)
                    z = torch.squeeze(z, dim=2)
                elif self.mode == "mha":
                    w = torch.transpose(w, 1, 2)
                    # print("mha shapes: ", w.shape, self.Q.shape)
                    Qx = torch.transpose(torch.matmul(w, self.Q), 0, 1)
                    HK = torch.transpose(torch.matmul(H, self.K), 0, 1)
                    HV = torch.transpose(torch.matmul(H, self.V), 0, 1)
                    multihead_attn = nn.MultiheadAttention(self.emb_m, self.nheads)
                    attn_output, _ = multihead_attn(Qx, HK, HV)
                    # print("attn shape: ", attn_output.shape)
                    z = torch.squeeze(attn_output, dim=0)
                else:  # concat
                    H = torch.flatten(H, start_dim=1, end_dim=2)
                    w = torch.flatten(w, start_dim=1, end_dim=2)
                    z = torch.cat([H, w], dim=1)
                # obtain probability of a click as a result of MLP
                p = dlrm.DLRM_Net().apply_mlp(z, self.mlp)

                return p

        def train_inner(znet):

            loss_fn = torch.nn.BCELoss(reduction="mean")
            # loss_fn = torch.nn.L1Loss(reduction="mean")
            optimizer = torch.optim.Adagrad(znet.parameters(), lr=0.05)
            # optimizer = torch.optim.SGD(znet.parameters(), lr=0.05)

            znet.train()
            nepochs = 1
            for _ in range(nepochs):
                TA = 0
                TS = 0
                for _, (X, y) in enumerate(ztrain_ld):

                    batchSize = X.shape[0]

                    # forward pass
                    Z = znet(X)

                    # loss
                    # print("Z, y: ", Z.shape, y.shape)
                    E = loss_fn(Z, y)
                    # compute loss and accuracy
                    # L = E.detach().cpu().numpy()  # numpy array
                    z = Z.detach().cpu().numpy()  # numpy array
                    t = y.detach().cpu().numpy()  # numpy array

                    # rounding t: smooth labels case
                    A = np.sum((np.round(z, 0) == np.round(t, 0)).astype(np.uint16))
                    TA += A
                    TS += batchSize

                    optimizer.zero_grad()
                    # backward pass (the graph is rebuilt every batch, so
                    # retain_graph is not needed)
                    E.backward()
                    # optimizer
                    optimizer.step()
                    # if j % 500 == 0:
                    #     acc = 100.0 * TA / TS
                    #     print("j, acc: ", j, acc)
                    #     TA = 0
                    #     TS = 0

            z_final = np.zeros(Nt, dtype=np.float64)
            offset = 0
            znet.eval()
            with torch.no_grad():  # inference only; skip gradient tracking
                for _, (X, _) in enumerate(ztest_ld):

                    batchSize = X.shape[0]
                    Z = znet(X)
                    z_final[offset: offset + batchSize] = \
                        np.squeeze(Z.detach().cpu().numpy(), axis=1)
                    offset += batchSize
                    # E = loss_fn(Z, y)
                    # L = E.detach().cpu().numpy()  # numpy array

            # loss_net = L
            # print(znet.num_inner, znet.mode, ": ", loss_net)
            auc_net = 100.0 * roc_auc_score(t_test.astype(int), z_final)
            print(znet.num_inner, znet.mode, ": ", auc_net)

            return auc_net

        dim = T
        znet = TBSM_SubNet("dot", 1, D, dim)  # plain dot-product baseline
        res1 = train_inner(znet)
        znet = TBSM_SubNet("def", 1, D, dim)  # quadratic form, 1 inner term
        res2 = train_inner(znet)
        znet = TBSM_SubNet("def", 4, D, dim)  # quadratic form, 4 inner terms
        res3 = train_inner(znet)
        znet = TBSM_SubNet("def", 8, D, dim)  # quadratic form, 8 inner terms
        res4 = train_inner(znet)
        znet = TBSM_SubNet("mha", 1, D, dim)  # multi-head attention variant
        res5 = train_inner(znet)
        auc_results = np.append(
            auc_results, np.array([[res1, res2, res3, res4, res5]]), axis=0
        )
    print(auc_results)  # rows: K = 0, 10, 20, 30; cols: dot, def(1), def(4), def(8), mha
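

# example driver (an assumed addition, not part of the original listing);
# the fixed seeds are arbitrary and only make repeated runs comparable
if __name__ == "__main__":
    np.random.seed(123)
    torch.manual_seed(123)
    synthetic_experiment()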