in tools/visualize.py [0:0]
def dlrm_output_wrap(dlrm, X, lS_o, lS_i, T, class_thresh=0.0):
    """Run ``dlrm`` stage by stage and collect intermediate values as NumPy arrays.

    The forward pass is unrolled manually: bottom MLP, embeddings, feature
    interaction, then each layer of the top MLP. This lets every
    intermediate output be captured for visualization.

    NOTE(review): ``x[0]``, ``e[0]`` and ``T[0, 0]`` read only the first row
    along axis 0 (presumably the first sample of the batch), while ``z_out``
    flattens the whole tensor. This appears to assume a batch size of 1;
    confirm against callers.

    Args:
        dlrm: model exposing ``apply_mlp``, ``bot_l``, ``apply_emb``,
            ``emb_l``, ``interact_features``, ``top_l`` and ``loss_threshold``.
        X: dense input, passed to ``dlrm.apply_mlp`` with ``dlrm.bot_l``.
        lS_o: forwarded unchanged to ``dlrm.apply_emb``.
        lS_i: forwarded unchanged to ``dlrm.apply_emb``.
        T: target tensor; ``T[0, 0]`` is read as the integer label.
        class_thresh: offset added to the output score before rounding it to
            a class. With the default 0.0, scores >= 0.5 map to class 1.

    Returns:
        Tuple ``(all_feat_vec, x_vec, all_cat_vec, t_out, c_out, z_out, p_out)``:
          * all_feat_vec: ``x_vec`` concatenated with every embedding vector.
          * x_vec: first row of the bottom-MLP output.
          * all_cat_vec: concatenation of the first row of every embedding
            output (an empty array if there are no embedding outputs).
          * t_out: ``int(T[0, 0])``.
          * c_out: 0 if ``p_out == t_out``, otherwise 1.
          * z_out: flattened interaction output followed by the flattened
            output of each ``top_l`` layer.
          * p_out: predicted class in {0, 1}.
    """

    def to_np(t):
        # Detach from autograd and move to host memory before converting.
        return t.detach().cpu().numpy()

    # Bottom MLP over the dense features.
    x = dlrm.apply_mlp(X, dlrm.bot_l)
    x_vec = to_np(x[0])

    # Sparse features through the embedding tables. Convert each vector once
    # and reuse it for both the categorical and the full feature vector.
    ly = dlrm.apply_emb(lS_o, lS_i, dlrm.emb_l)
    cat_vecs = [to_np(e[0]) for e in ly]
    all_feat_vec = np.concatenate([x_vec] + cat_vecs, axis=0)
    if cat_vecs:
        all_cat_vec = np.concatenate(cat_vecs, axis=0)
    else:
        # np.concatenate raises on an empty list; return an empty vector instead.
        all_cat_vec = np.empty(0, dtype=x_vec.dtype)

    t_out = int(to_np(T)[0, 0])

    # Feature interaction followed by each top-MLP layer; record every stage.
    z = dlrm.interact_features(x, ly)
    z_out = [to_np(z).flatten()]
    for layer in dlrm.top_l:
        z = layer(z)
        z_out.append(to_np(z).flatten())

    # Clamp the output when the model's loss_threshold lies strictly in (0, 1).
    thr = dlrm.loss_threshold
    if 0.0 < thr < 1.0:
        z = torch.clamp(z, min=thr, max=1.0 - thr)

    # Round the shifted score of the first entry and clamp it to {0, 1}.
    zp = to_np(z)[0, 0] + class_thresh
    p_out = min(max(int(zp + 0.5), 0), 1)

    # 0 when the prediction matches the target, 1 when it is misclassified.
    c_out = 0 if p_out == t_out else 1

    return all_feat_vec, x_vec, all_cat_vec, t_out, c_out, z_out, p_out