def decode_sample()

in src/sagemaker_sklearn_extension/contrib/taei/models.py [0:0]


    def decode_sample(self, z):
        """Decode ``z`` and stitch the outputs into one tensor, one column per feature.

        ``self.decode(z)`` is consumed front-to-back with ``pop(0)``: first the
        continuous output (only when the instance has a ``cont_net`` attribute),
        then one output per entry of ``self.categorical_features`` (only when the
        instance has a ``cat_nets`` attribute). Each categorical output is
        reduced to its ``argmax`` along dimension 1.

        Args:
            z: Value passed straight through to ``self.decode``.
                # NOTE(review): presumably a batch of latent codes — confirm.

        Returns:
            The result of ``torch.cat(..., dim=1)`` over the per-feature columns,
            ordered by feature index ``0 .. self.input_dim - 1``. Indices found in
            neither ``self.continuous_features`` nor ``self.categorical_features``
            contribute no column.
        """
        outputs = self.decode(z)

        # The continuous output, when present, is the first element.
        continuous = outputs.pop(0) if hasattr(self, "cont_net") else []

        # The categorical outputs follow; keep only the argmax index of each.
        labels = []
        if hasattr(self, "cat_nets"):
            labels = [torch.argmax(outputs.pop(0), dim=1) for _ in self.categorical_features]

        # Walk the feature indices in order, drawing the next unused column from
        # whichever group the index belongs to (continuous is checked first).
        columns = []
        next_cont = 0
        next_cat = 0
        for idx in range(self.input_dim):
            if idx in self.continuous_features:
                columns.append(continuous[:, next_cont].reshape(-1, 1))
                next_cont += 1
            elif idx in self.categorical_features:
                columns.append(labels[next_cat].reshape(-1, 1))
                next_cat += 1

        return torch.cat(columns, dim=1)