def call()

in bert_layer.py [0:0]


    def call(self, inputs):
        """Run the wrapped BERT module on ``inputs`` and return one of its outputs.

        Args:
            inputs: A sequence of exactly three tensor-like values, unpacked in
                the order ``(input_ids, input_mask, segment_ids)``. Each one is
                cast to ``int32`` before being passed to ``self.bert``.

        Returns:
            The entry of ``self.bert(..., signature="tokens", as_dict=True)``
            selected by ``self.pooling``:

            * ``"first"`` -> ``"pooled_output"``
            * ``"mean"``  -> ``"sequence_output"``, returned unchanged
              (no averaging is applied; see the NOTE below).

        Raises:
            NameError: If ``self.pooling`` is neither ``"first"`` nor ``"mean"``.
                NameError is kept so existing callers that catch it keep working,
                although ValueError would be the conventional type here.
        """
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )

        # Resolve which BERT output to return before calling the module, so an
        # unknown pooling mode fails without invoking self.bert.
        if self.pooling == "first":
            output_key = "pooled_output"
        elif self.pooling == "mean":
            # NOTE(review): despite its name, "mean" does NOT average over tokens.
            # It returns the full per-token "sequence_output". A masked mean was
            # previously written here but disabled:
            #     sum(x * mask[..., None], axis=1) / (sum(mask, axis=1) + 1e-10)
            # with input_mask cast to float32. That was presumably deliberate,
            # because downstream layers may expect per-token outputs. Check any
            # output-shape logic elsewhere in the class before re-enabling it,
            # since re-enabling it drops the token axis.
            output_key = "sequence_output"
        else:
            raise NameError(
                f"Undefined pooling type (must be either first or mean, but is {self.pooling})"
            )

        return self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            output_key
        ]