src/sagemaker_sklearn_extension/contrib/taei/models.py [264:271]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        x_hat = output[0]
        if self.continuous_features:
            out = x_hat.pop(0)
            loss["mse"] = nn.functional.mse_loss(target[:, self.continuous_features], out)
        if self.categorical_features:
            for idx in self.categorical_features:
                out = x_hat.pop(0)
                loss["nll"] += nn.functional.nll_loss(out, target[:, idx].long())
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/sagemaker_sklearn_extension/contrib/taei/models.py [394:401]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        x_hat = output[0]
        if self.continuous_features:
            out = x_hat.pop(0)
            loss["mse"] = nn.functional.mse_loss(target[:, self.continuous_features], out)
        if self.categorical_features:
            for idx in self.categorical_features:
                out = x_hat.pop(0)
                loss["nll"] += nn.functional.nll_loss(out, target[:, idx].long())
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



