def make_dragonnet()

in causalml/inference/nn/dragonnet.py [0:0]


    def make_dragonnet(self, input_dim):
        """
        Neural net predictive model. The dragon has three heads.

        A shared representation (three ELU dense layers of
        ``self.neurons_per_layer`` units each) feeds:

        * a treatment head: a single sigmoid unit (``t_predictions``);
        * two outcome heads, ``y0`` and ``y1`` (presumably the untreated and
          treated outcomes): each is two L2-regularized ELU layers of half the
          representation width followed by a single linear unit.

        An ``EpsilonLayer`` is applied to the treatment-head output to produce
        the ``epsilon`` term (presumably for a targeted-regularization loss --
        its semantics live in ``EpsilonLayer``, not shown here).

        Args:
            input_dim (int): number of features (columns) in each input row
        Returns:
            model (keras.models.Model): DragonNet model with a single output
                that concatenates, along axis 1 and in this order, the y0
                prediction, the y1 prediction, the treatment prediction and
                the epsilon term
        """
        inputs = Input(shape=(input_dim,), name='input')

        # Shared representation: three identical ELU layers stacked on the input.
        x = inputs
        for _ in range(3):
            x = Dense(units=self.neurons_per_layer,
                      activation='elu',
                      kernel_initializer='RandomNormal')(x)

        # Treatment head.
        t_predictions = Dense(units=1, activation='sigmoid')(x)

        # HYPOTHESIS (outcome heads). Both hidden layers of both heads use half
        # the representation width; int(... / 2) keeps the original truncation.
        hidden_units = int(self.neurons_per_layer / 2)

        def hidden_layer(tensor):
            """Apply one L2-regularized ELU layer of ``hidden_units`` units."""
            return Dense(units=hidden_units,
                         activation='elu',
                         kernel_regularizer=regularizers.l2(self.reg_l2))(tensor)

        # Layers are created in the same interleaved order as before (y0, y1,
        # y0, y1) so auto-generated layer names and weight-initialization
        # order are unchanged.
        y0_hidden = hidden_layer(x)
        y1_hidden = hidden_layer(x)

        # second layer
        y0_hidden = hidden_layer(y0_hidden)
        y1_hidden = hidden_layer(y1_hidden)

        # third: linear output units, explicitly named.
        y0_predictions = Dense(units=1,
                               activation=None,
                               kernel_regularizer=regularizers.l2(self.reg_l2),
                               name='y0_predictions')(y0_hidden)
        y1_predictions = Dense(units=1,
                               activation=None,
                               kernel_regularizer=regularizers.l2(self.reg_l2),
                               name='y1_predictions')(y1_hidden)

        # NOTE(review): `name` is passed to the layer *call*, not to the
        # constructor, so whether it actually names the layer depends on
        # EpsilonLayer.call (not visible here) -- likely it is ignored. Kept
        # unchanged for compatibility.
        dl = EpsilonLayer()
        epsilons = dl(t_predictions, name='epsilon')

        # Output columns, in order: y0, y1, t, epsilon.
        concat_pred = Concatenate(1)([y0_predictions, y1_predictions, t_predictions, epsilons])
        model = Model(inputs=inputs, outputs=concat_pred)

        return model