in bert_layer.py [0:0]
def call(self, inputs):
    """Run the BERT module on token inputs and return the selected output.

    Args:
        inputs: a sequence of exactly three tensors, unpacked in the order
            ``input_ids, input_mask, segment_ids``. Each one is cast to
            int32 with ``K.cast`` before being fed to ``self.bert``.

    Returns:
        If ``self.pooling == "first"``: the ``"pooled_output"`` entry of the
        module's output dict.
        If ``self.pooling == "mean"``: the ``"sequence_output"`` entry,
        returned as-is. No mean reduction is applied (see the review note
        in the body).

    Raises:
        NameError: if ``self.pooling`` is neither ``"first"`` nor ``"mean"``.
            The type is kept as NameError for backward compatibility with
            existing callers.
    """
    # Unpacking raises ValueError if `inputs` does not hold exactly 3 tensors.
    input_ids, input_mask, segment_ids = [K.cast(x, dtype="int32") for x in inputs]
    bert_inputs = dict(
        input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
    )

    if self.pooling == "first":
        output_key = "pooled_output"
    elif self.pooling == "mean":
        # NOTE(review): despite its name, "mean" returns the module's
        # "sequence_output" without any reduction. The masked mean below was
        # disabled in the original code, possibly on purpose (e.g. for a
        # token-level head). TODO: confirm against callers and the layer's
        # compute_output_shape before re-enabling it:
        #   mask = tf.cast(input_mask, tf.float32)
        #   pooled = tf.reduce_sum(result * tf.expand_dims(mask, -1), axis=1) / (
        #       tf.reduce_sum(mask, axis=1, keepdims=True) + 1e-10)
        output_key = "sequence_output"
    else:
        raise NameError(
            f"Undefined pooling type (must be either first or mean, but is {self.pooling})"
        )

    # Only one module call is needed; the branches differ solely in which
    # entry of the output dict they select.
    outputs = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)
    return outputs[output_key]