ludwig/features/numerical_feature.py [40:82]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


# TODO TF2 can we eliminate use of these custom wrapper classes?
# custom class to handle how Ludwig stores predictions
class MSELoss(MeanSquaredError):
    """Mean squared error loss that unpacks Ludwig's prediction container.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``LOGITS`` entry and
    delegates to the parent ``MeanSquaredError`` loss.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent loss.
        super().__init__(**kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the MSE between ``y_true`` and ``y_pred[LOGITS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the ``LOGITS``
                entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            Whatever the parent ``MeanSquaredError.__call__`` returns.
        """
        return super().__call__(
            y_true, y_pred[LOGITS], sample_weight=sample_weight
        )


class MSEMetric(MeanSquaredErrorMetric):
    """Mean squared error metric that reads ``y_pred[PREDICTIONS]``.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``PREDICTIONS`` entry and
    delegates accumulation to the parent metric.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent metric.
        super().__init__(**kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate MSE statistics for ``y_true`` vs ``y_pred[PREDICTIONS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the
                ``PREDICTIONS`` entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            The parent ``update_state`` result. Keras documents this as the
            update op. It is returned (not discarded) so Keras' update-state
            wrapper can still register it.
        """
        return super().update_state(
            y_true, y_pred[PREDICTIONS], sample_weight=sample_weight
        )


class MAELoss(MeanAbsoluteError):
    """Mean absolute error loss that unpacks Ludwig's prediction container.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``LOGITS`` entry and
    delegates to the parent ``MeanAbsoluteError`` loss.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent loss.
        super().__init__(**kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the MAE between ``y_true`` and ``y_pred[LOGITS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the ``LOGITS``
                entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            Whatever the parent ``MeanAbsoluteError.__call__`` returns.
        """
        return super().__call__(
            y_true, y_pred[LOGITS], sample_weight=sample_weight
        )


class MAEMetric(MeanAbsoluteErrorMetric):
    """Mean absolute error metric that reads ``y_pred[PREDICTIONS]``.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``PREDICTIONS`` entry and
    delegates accumulation to the parent metric.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent metric.
        super().__init__(**kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate MAE statistics for ``y_true`` vs ``y_pred[PREDICTIONS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the
                ``PREDICTIONS`` entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            The parent ``update_state`` result. Keras documents this as the
            update op. It is returned (not discarded) so Keras' update-state
            wrapper can still register it.
        """
        return super().update_state(
            y_true, y_pred[PREDICTIONS], sample_weight=sample_weight
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



ludwig/features/vector_feature.py [41:85]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


# TODO TF2 can we eliminate use of these custom wrapper classes?
#  These are copies of the classes in numerical_modules;
#  depending on what we end up doing with those, these will follow.
# custom class to handle how Ludwig stores predictions
class MSELoss(MeanSquaredError):
    """Mean squared error loss that unpacks Ludwig's prediction container.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``LOGITS`` entry and
    delegates to the parent ``MeanSquaredError`` loss.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent loss.
        super().__init__(**kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the MSE between ``y_true`` and ``y_pred[LOGITS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the ``LOGITS``
                entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            Whatever the parent ``MeanSquaredError.__call__`` returns.
        """
        return super().__call__(
            y_true, y_pred[LOGITS], sample_weight=sample_weight
        )


class MSEMetric(MeanSquaredErrorMetric):
    """Mean squared error metric that reads ``y_pred[PREDICTIONS]``.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``PREDICTIONS`` entry and
    delegates accumulation to the parent metric.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent metric.
        super().__init__(**kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate MSE statistics for ``y_true`` vs ``y_pred[PREDICTIONS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the
                ``PREDICTIONS`` entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            The parent ``update_state`` result. Keras documents this as the
            update op. It is returned (not discarded) so Keras' update-state
            wrapper can still register it.
        """
        return super().update_state(
            y_true, y_pred[PREDICTIONS], sample_weight=sample_weight
        )


class MAELoss(MeanAbsoluteError):
    """Mean absolute error loss that unpacks Ludwig's prediction container.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``LOGITS`` entry and
    delegates to the parent ``MeanAbsoluteError`` loss.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent loss.
        super().__init__(**kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the MAE between ``y_true`` and ``y_pred[LOGITS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the ``LOGITS``
                entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            Whatever the parent ``MeanAbsoluteError.__call__`` returns.
        """
        return super().__call__(
            y_true, y_pred[LOGITS], sample_weight=sample_weight
        )


class MAEMetric(MeanAbsoluteErrorMetric):
    """Mean absolute error metric that reads ``y_pred[PREDICTIONS]``.

    Ludwig hands the output feature's predictions over as a keyed container
    (presumably a dict); this wrapper selects the ``PREDICTIONS`` entry and
    delegates accumulation to the parent metric.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent metric.
        super().__init__(**kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate MAE statistics for ``y_true`` vs ``y_pred[PREDICTIONS]``.

        Args:
            y_true: target values.
            y_pred: keyed container of model outputs; only the
                ``PREDICTIONS`` entry is read.
            sample_weight: optional weights, passed through unchanged.

        Returns:
            The parent ``update_state`` result. Keras documents this as the
            update op. It is returned (not discarded) so Keras' update-state
            wrapper can still register it.
        """
        return super().update_state(
            y_true, y_pred[PREDICTIONS], sample_weight=sample_weight
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



