in causalml/inference/tf/dragonnet.py [0:0]
def fit(self, X, treatment, y, p=None):
    """Fits the DragonNet model.

    Args:
        X (np.matrix or np.array or pd.DataFrame): a feature matrix
        treatment (np.array or pd.Series): a treatment vector
        y (np.array or pd.Series): an outcome vector
        p (np.array or pd.Series, optional): a propensity score vector. It is
            not used by DragonNet and is accepted only for interface
            consistency with the other learners.
    """
    X, treatment, y = convert_pd_to_np(X, treatment, y)

    # Stack the outcome and treatment into a single 2-column target so the
    # DragonNet loss and metrics can access both during training.
    y = np.hstack((y.reshape(-1, 1), treatment.reshape(-1, 1)))

    self.dragonnet = self.make_dragonnet(X.shape[1])
    metrics = [
        regression_loss,
        binary_classification_loss,
        treatment_accuracy,
        track_epsilon,
    ]

    # Use the targeted-regularization loss when enabled; otherwise fall back
    # to the plain DragonNet loss.
    if self.targeted_reg:
        loss = make_tarreg_loss(ratio=self.ratio, dragonnet_loss=self.loss_func)
    else:
        loss = self.loss_func
    # Optional first training phase with the Adam optimizer.
    if self.use_adam:
        self.dragonnet.compile(
            optimizer=Adam(learning_rate=self.adam_learning_rate),
            loss=loss,
            metrics=metrics,
        )
        adam_callbacks = [
            TerminateOnNaN(),
            EarlyStopping(monitor="val_loss", patience=2, min_delta=0.0),
            ReduceLROnPlateau(
                monitor="loss",
                factor=0.5,
                patience=5,
                verbose=self.verbose,
                mode="auto",
                min_delta=1e-8,
                cooldown=0,
                min_lr=0,
            ),
        ]

        self.dragonnet.fit(
            X,
            y,
            callbacks=adam_callbacks,
            validation_split=self.val_split,
            epochs=self.adam_epochs,
            batch_size=self.batch_size,
            verbose=self.verbose,
        )
    # Second training phase (or the only phase when use_adam is False):
    # recompile with SGD and continue training with more patient early stopping.
    sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor="val_loss", patience=40, min_delta=0.0),
        ReduceLROnPlateau(
            monitor="loss",
            factor=0.5,
            patience=5,
            verbose=self.verbose,
            mode="auto",
            min_delta=0.0,
            cooldown=0,
            min_lr=0,
        ),
    ]

    self.dragonnet.compile(
        optimizer=SGD(
            learning_rate=self.learning_rate, momentum=self.momentum, nesterov=True
        ),
        loss=loss,
        metrics=metrics,
    )
    self.dragonnet.fit(
        X,
        y,
        callbacks=sgd_callbacks,
        validation_split=self.val_split,
        epochs=self.epochs,
        batch_size=self.batch_size,
        verbose=self.verbose,
    )
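
# --- Usage sketch (illustrative; not part of the original dragonnet.py) ---
# A minimal example of how fit() above is typically called, assuming the
# public DragonNet class exposed from causalml.inference.tf with its default
# constructor arguments. The synthetic data below is made up for illustration.
#
#     import numpy as np
#     from causalml.inference.tf import DragonNet
#
#     n, d = 1000, 10
#     X = np.random.normal(size=(n, d))                     # feature matrix
#     treatment = np.random.binomial(1, 0.5, size=n)        # binary treatment vector
#     y = X[:, 0] + treatment + np.random.normal(size=n)    # outcome vector
#
#     model = DragonNet()
#     model.fit(X, treatment, y)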