def fit()

in causalml/inference/tf/dragonnet.py [0:0]


    def fit(self, X, treatment, y, p=None):
        """
        Fits the DragonNet model.

        A fresh network is built (and stored on ``self.dragonnet``), then
        trained in up to two stages on the same weights:

        1. Optional Adam stage (only when ``self.use_adam`` is truthy), for
           ``self.adam_epochs`` epochs with early-stopping patience 2.
        2. SGD stage with Nesterov momentum, for ``self.epochs`` epochs with
           early-stopping patience 40.

        Both stages hold out ``self.val_split`` of the rows for validation.
        They stop early on ``val_loss``, halve the learning rate when the
        training ``loss`` plateaus, and abort on NaN loss.

        Args:
            X (np.matrix or np.array or pd.Dataframe): a feature matrix
            treatment (np.array or pd.Series): a treatment vector
            y (np.array or pd.Series): an outcome vector
            p (optional): ignored. It is accepted only for signature
                compatibility (presumably with other learners'
                ``fit(X, treatment, y, p)``). The treatment column is
                stacked into the training target instead, so supplied
                propensity scores are not used.

        Returns:
            None. The trained model is kept on ``self.dragonnet``.
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)

        # The network is trained against a 2-column target: column 0 is the
        # outcome, column 1 is the treatment indicator.
        y = np.hstack((y.reshape(-1, 1), treatment.reshape(-1, 1)))

        self.dragonnet = self.make_dragonnet(X.shape[1])

        metrics = [
            regression_loss,
            binary_classification_loss,
            treatment_accuracy,
            track_epsilon,
        ]

        if self.targeted_reg:
            loss = make_tarreg_loss(ratio=self.ratio, dragonnet_loss=self.loss_func)
        else:
            loss = self.loss_func

        def _stage_callbacks(early_stop_patience, lr_min_delta):
            """Build the NaN guard, early-stopping and LR-decay callbacks for one stage."""
            return [
                TerminateOnNaN(),
                EarlyStopping(
                    monitor="val_loss", patience=early_stop_patience, min_delta=0.0
                ),
                ReduceLROnPlateau(
                    monitor="loss",
                    factor=0.5,
                    patience=5,
                    verbose=self.verbose,
                    mode="auto",
                    min_delta=lr_min_delta,
                    cooldown=0,
                    min_lr=0,
                ),
            ]

        def _run_stage(optimizer, callbacks, epochs):
            """(Re)compile the network with ``optimizer`` and train it for ``epochs``."""
            # Recompiling swaps the optimizer but keeps the current weights, so
            # the SGD stage continues from wherever the Adam stage stopped.
            self.dragonnet.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            self.dragonnet.fit(
                X,
                y,
                callbacks=callbacks,
                validation_split=self.val_split,
                epochs=epochs,
                batch_size=self.batch_size,
                verbose=self.verbose,
            )

        if self.use_adam:
            # Short warm-up stage: stop quickly once validation loss stalls.
            _run_stage(
                optimizer=Adam(learning_rate=self.adam_learning_rate),
                callbacks=_stage_callbacks(early_stop_patience=2, lr_min_delta=1e-8),
                epochs=self.adam_epochs,
            )

        # Main stage: SGD with a much more patient early-stopping criterion.
        _run_stage(
            optimizer=SGD(
                learning_rate=self.learning_rate, momentum=self.momentum, nesterov=True
            ),
            callbacks=_stage_callbacks(early_stop_patience=40, lr_min_delta=0.0),
            epochs=self.epochs,
        )