in src/sagemaker_sklearn_extension/contrib/taei/models.py [0:0]
def decode_sample(self, z):
    """Decode latent codes ``z`` into one feature matrix in input-column order.

    ``self.decode(z)`` is expected to return a mutable sequence of tensors.
    Its entries are consumed from the front with ``pop(0)``, so the sequence
    is left mutated:

    * if the model has a ``cont_net`` attribute, the first entry is taken as
      the continuous reconstruction; column ``k`` of it fills the ``k``-th
      continuous feature;
    * if the model has a ``cat_nets`` attribute, one further entry is taken
      per categorical feature and reduced with ``torch.argmax(..., dim=1)``
      to a per-row category index.

    The output is assembled by walking input indices ``0 .. input_dim - 1``.
    Each index found in ``continuous_features`` or ``categorical_features``
    contributes one column (reshaped to ``(-1, 1)``). An index in neither
    list contributes nothing.

    Args:
        z: Latent input forwarded unchanged to ``self.decode``.

    Returns:
        The columns joined with ``torch.cat(..., dim=1)``.
    """
    decoded = self.decode(z)

    # Continuous reconstruction (if any) comes first in the decoder output.
    continuous_block = decoded.pop(0) if hasattr(self, "cont_net") else []

    # One head per categorical feature follows. Keep the most likely class index.
    categorical_cols = (
        [torch.argmax(decoded.pop(0), dim=1) for _ in self.categorical_features]
        if hasattr(self, "cat_nets")
        else []
    )

    # NOTE(review): the running counters below assume both feature lists
    # enumerate their indices in the same order as the decoder's output
    # columns/heads (e.g. ascending). Confirm against the constructor.
    columns = []
    next_cont = next_cat = 0
    for idx in range(self.input_dim):
        if idx in self.continuous_features:
            columns.append(continuous_block[:, next_cont].reshape(-1, 1))
            next_cont += 1
        elif idx in self.categorical_features:
            columns.append(categorical_cols[next_cat].reshape(-1, 1))
            next_cat += 1

    return torch.cat(columns, dim=1)