def make_dragonnet()

in causalml/inference/tf/dragonnet.py [0:0]


    def make_dragonnet(self, input_dim):
        """
        Neural net predictive model. The dragon has three heads.

        A shared representation (three ELU dense layers of width
        ``self.neurons_per_layer``) feeds three heads:

        * a treatment head (single sigmoid unit),
        * a y0 outcome head and a y1 outcome head, each made of two
          L2-regularized ELU layers of half the representation width followed
          by a single linear unit.

        Args:
            input_dim (int): number of features (columns) in each input sample;
                used as the width of the ``Input`` layer
        Returns:
            model (keras.models.Model): DragonNet model whose output is the
                concatenation along axis 1 of, in order: the y0 prediction, the
                y1 prediction, the treatment prediction, and the output of the
                ``EpsilonLayer``
        """
        inputs = Input(shape=(input_dim,), name="input")

        # Shared representation: three identical ELU layers. Layers are created
        # and applied in the same order as the original hand-unrolled version
        # so auto-generated layer names and weight-creation order are unchanged.
        x = inputs
        for _ in range(3):
            x = Dense(
                units=self.neurons_per_layer,
                activation="elu",
                kernel_initializer="RandomNormal",
            )(x)

        # Treatment head: one sigmoid unit on top of the shared representation.
        t_predictions = Dense(units=1, activation="sigmoid")(x)

        # HYPOTHESIS: outcome heads use half the representation width.
        # int(... / 2) (not //) is kept so a float neurons_per_layer still
        # produces an int `units`.
        hidden_units = int(self.neurons_per_layer / 2)

        # Two hidden layers per head. Inside each iteration the y0 layer is
        # created before the y1 layer, reproducing the original interleaved
        # creation order (y0, y1, y0, y1).
        y0_hidden = x
        y1_hidden = x
        for _ in range(2):
            y0_hidden = Dense(
                units=hidden_units,
                activation="elu",
                kernel_regularizer=l2(self.reg_l2),
            )(y0_hidden)
            y1_hidden = Dense(
                units=hidden_units,
                activation="elu",
                kernel_regularizer=l2(self.reg_l2),
            )(y1_hidden)

        # Final linear unit for each outcome head.
        y0_predictions = Dense(
            units=1,
            activation=None,
            kernel_regularizer=l2(self.reg_l2),
            name="y0_predictions",
        )(y0_hidden)
        y1_predictions = Dense(
            units=1,
            activation=None,
            kernel_regularizer=l2(self.reg_l2),
            name="y1_predictions",
        )(y1_hidden)

        dl = EpsilonLayer()
        # NOTE(review): `name` is passed to the layer *call*, not to the
        # constructor, so presumably it is forwarded to EpsilonLayer.call
        # instead of naming the layer. Kept as-is; confirm intent against
        # EpsilonLayer's definition.
        epsilons = dl(t_predictions, name="epsilon")

        # Output column order (y0, y1, t, epsilon) must stay fixed; anything
        # that slices the model output presumably depends on it.
        concat_pred = Concatenate(axis=1)(
            [y0_predictions, y1_predictions, t_predictions, epsilons]
        )
        model = Model(inputs=inputs, outputs=concat_pred)

        return model