def dlrm_output_wrap()

in tools/visualize.py [0:0]


def dlrm_output_wrap(dlrm, X, lS_o, lS_i, T):
    """Run one forward pass through ``dlrm`` and capture its intermediate outputs.

    The pass is performed step by step (bottom MLP, embeddings, feature
    interaction, then each top layer in turn) so that every intermediate
    result can be exported as NumPy data.

    Args:
        dlrm: Model object providing ``apply_mlp``, ``apply_emb``,
            ``interact_features``, ``bot_l``, ``emb_l``, ``top_l`` and
            ``loss_threshold``.
        X: Dense input, forwarded unchanged to ``dlrm.apply_mlp``.
        lS_o: Forwarded unchanged to ``dlrm.apply_emb``.
        lS_i: Forwarded unchanged to ``dlrm.apply_emb``.
        T: Target tensor; only element ``[0, 0]`` is read.

    Returns:
        A 7-tuple ``(all_feat_vec, x_vec, all_cat_vec, t_out, c_out, z_out,
        p_out)``:

        * ``all_feat_vec`` -- ``x_vec`` followed by every embedding row,
          concatenated along axis 0.
        * ``x_vec`` -- element ``[0]`` of the bottom-MLP output.
        * ``all_cat_vec`` -- the embedding rows only, concatenated along
          axis 0.
        * ``t_out`` -- ``T[0, 0]`` truncated to ``int``.
        * ``c_out`` -- ``0`` when ``p_out == t_out``, otherwise ``1``.
        * ``z_out`` -- list of flattened arrays: the interaction output
          followed by the output of each layer in ``dlrm.top_l``.
        * ``p_out`` -- predicted class, ``0`` or ``1``.

    Note:
        Only index ``0`` of the bottom-MLP output, of each embedding output
        and of ``T`` is captured, whereas ``z_out`` flattens whole tensors.
        Presumably this is meant to be called with a batch of one sample --
        TODO confirm against the caller.
    """

    def _as_numpy(tensor):
        # Drop the autograd graph and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    n_top_layers = len(dlrm.top_l)

    # Dense branch: bottom MLP, first row only.
    dense_out = dlrm.apply_mlp(X, dlrm.bot_l)
    x_vec = _as_numpy(dense_out[0])

    # Sparse branch: one embedding output per table, first row of each.
    emb_out = dlrm.apply_emb(lS_o, lS_i, dlrm.emb_l)
    emb_rows = [_as_numpy(emb[0]) for emb in emb_out]

    all_feat_vec = np.concatenate([x_vec] + emb_rows, axis=0)
    all_cat_vec = np.concatenate(emb_rows, axis=0)

    t_out = int(_as_numpy(T)[0, 0])

    # Record the interaction output, then the activation after every top layer.
    activation = dlrm.interact_features(dense_out, emb_out)
    z_out = [_as_numpy(activation).flatten()]
    for layer_idx in range(n_top_layers):
        activation = dlrm.top_l[layer_idx](activation)
        z_out.append(_as_numpy(activation).flatten())

    # Clamp the final output only when the threshold lies strictly in (0, 1).
    if 0.0 < dlrm.loss_threshold < 1.0:
        score = torch.clamp(
            activation,
            min=dlrm.loss_threshold,
            max=(1.0 - dlrm.loss_threshold),
        )
    else:
        score = activation

    # Offset applied to the score before rounding; 0.0 leaves it unchanged.
    class_thresh = 0.0
    shifted = _as_numpy(score)[0, 0] + class_thresh

    # Round to the nearest class and force the result into {0, 1}.
    p_out = min(max(int(shifted + 0.5), 0), 1)

    # 0 marks a prediction that matches the target, 1 a mismatch.
    c_out = 0 if p_out == t_out else 1

    return all_feat_vec, x_vec, all_cat_vec, t_out, c_out, z_out, p_out