in causalml/inference/nn/dragonnet.py [0:0]
def fit(self, X, treatment, y):
    """
    Fits the DragonNet model.

    Training runs in two stages on the same network built by
    ``self.make_dragonnet``:

    1. Adam with learning rate ``self.learning_rate`` for up to
       ``self.epochs`` epochs, early-stopped on ``val_loss`` (patience 2).
    2. Nesterov-momentum SGD (learning rate 1e-5, momentum 0.9) for up to
       300 epochs, early-stopped on ``val_loss`` (patience 40).

    Both stages hold out ``self.val_split`` of the rows as validation data
    and halve the learning rate when the training loss plateaus.

    Args:
        X (np.matrix or np.array or pd.DataFrame): a feature matrix
        treatment (np.array or pd.Series): a treatment vector
        y (np.array or pd.Series): an outcome vector
    """
    X, treatment, y = convert_pd_to_np(X, treatment, y)
    # The loss and metrics are fed outcome and treatment together as one
    # two-column target: column 0 is the outcome, column 1 the treatment.
    y_train = np.hstack((y.reshape(-1, 1), treatment.reshape(-1, 1)))

    self.dragonnet = self.make_dragonnet(X.shape[1])

    metrics = [regression_loss, binary_classification_loss, treatment_accuracy, track_epsilon]
    if self.targeted_reg:
        loss = make_tarreg_loss(ratio=self.ratio, dragonnet_loss=self.loss_func)
    else:
        loss = self.loss_func

    # Stage 1: Adam. ``lr`` is only a deprecated alias in tf.keras optimizers
    # and is rejected by newer Keras releases, so use ``learning_rate``.
    self.dragonnet.compile(
        optimizer=Adam(learning_rate=self.learning_rate),
        loss=loss,
        metrics=metrics,
    )
    adam_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=2, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=self.verbose,
                          mode='auto', min_delta=1e-8, cooldown=0, min_lr=0),
    ]
    self.dragonnet.fit(X, y_train,
                       callbacks=adam_callbacks,
                       validation_split=self.val_split,
                       epochs=self.epochs,
                       batch_size=self.batch_size,
                       verbose=self.verbose)

    # Stage 2: fine-tune with Nesterov SGD. Re-compiling swaps in the new
    # optimizer; Keras' compile() does not reinitialize layer weights, so this
    # continues from the weights reached in stage 1.
    sgd_lr = 1e-5
    sgd_momentum = 0.9
    sgd_epochs = 300
    sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=self.verbose,
                          mode='auto', min_delta=0., cooldown=0, min_lr=0),
    ]
    self.dragonnet.compile(
        optimizer=SGD(learning_rate=sgd_lr, momentum=sgd_momentum, nesterov=True),
        loss=loss,
        metrics=metrics,
    )
    self.dragonnet.fit(X, y_train,
                       callbacks=sgd_callbacks,
                       validation_split=self.val_split,
                       epochs=sgd_epochs,
                       batch_size=self.batch_size,
                       verbose=self.verbose)