def fit()

in causalml/inference/nn/dragonnet.py [0:0]


    def fit(self, X, treatment, y):
        """
        Fits the DragonNet model.

        A fresh network is built with ``self.make_dragonnet`` and stored in ``self.dragonnet``.
        It is then trained in two consecutive phases on the same data: first with the Adam
        optimizer, then with SGD using Nesterov momentum. Both phases share the same loss,
        metrics, batch size and validation split.

        Args:
            X (np.matrix or np.array or pd.Dataframe): a feature matrix
            treatment (np.array or pd.Series): a treatment vector
            y (np.array or pd.Series): an outcome vector
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)

        # Keras receives one target array: outcome in the first column, treatment in the second.
        y = np.hstack((y.reshape(-1, 1), treatment.reshape(-1, 1)))

        # The input width of the network is the number of feature columns.
        self.dragonnet = self.make_dragonnet(X.shape[1])

        metrics = [regression_loss, binary_classification_loss, treatment_accuracy, track_epsilon]

        # Optionally wrap the base loss with the targeted-regularization term.
        if self.targeted_reg:
            loss = make_tarreg_loss(ratio=self.ratio, dragonnet_loss=self.loss_func)
        else:
            loss = self.loss_func

        # ---- Phase 1: Adam ----
        # `learning_rate` is the documented keyword; the old `lr` alias is deprecated and is
        # rejected by recent Keras releases.
        self.dragonnet.compile(
            optimizer=Adam(learning_rate=self.learning_rate),
            loss=loss, metrics=metrics)

        adam_callbacks = [
            TerminateOnNaN(),
            # Short patience: stop this phase quickly once validation loss stops improving.
            EarlyStopping(monitor='val_loss', patience=2, min_delta=0.),
            ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=self.verbose, mode='auto',
                              min_delta=1e-8, cooldown=0, min_lr=0)
        ]

        self.dragonnet.fit(X, y,
                           callbacks=adam_callbacks,
                           validation_split=self.val_split,
                           epochs=self.epochs,
                           batch_size=self.batch_size,
                           verbose=self.verbose)

        # ---- Phase 2: SGD with Nesterov momentum ----
        # Re-compiling swaps the optimizer; training continues on the same model object.
        sgd_callbacks = [
            TerminateOnNaN(),
            # Much longer patience than in the Adam phase.
            EarlyStopping(monitor='val_loss', patience=40, min_delta=0.),
            ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=self.verbose, mode='auto',
                              min_delta=0., cooldown=0, min_lr=0)
        ]

        # Fixed hyper-parameters for the SGD phase (not configurable through the instance).
        sgd_lr = 1e-5
        momentum = 0.9
        self.dragonnet.compile(
            optimizer=SGD(learning_rate=sgd_lr, momentum=momentum, nesterov=True),
            loss=loss, metrics=metrics)
        self.dragonnet.fit(X, y,
                           callbacks=sgd_callbacks,
                           validation_split=self.val_split,
                           epochs=300,
                           batch_size=self.batch_size,
                           verbose=self.verbose)