in evaluations/evaluator.py [0:0]
def __init__(self, session):
    """Build the graph ops that compute pairwise distances between batches.

    Distances are first computed from float16 casts of the two feature
    placeholders. If every entry of that result is finite it is cast back
    to float32 and used; otherwise the distances are recomputed directly
    from the float32 placeholders. Two further ops compare the distances
    against the radii placeholders and reduce the comparison with a
    logical "any".

    Args:
        session: Object exposing a ``graph`` attribute; all ops are
            created inside that graph and the object is kept on
            ``self.session``.
    """
    self.session = session

    # Every op below is registered on the session's own graph.
    with session.graph.as_default():
        # Feature inputs: float32, rank 2, both dimensions left unspecified.
        self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
        self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])

        # First attempt runs on half-precision copies of the inputs.
        half_batch1 = tf.cast(self._features_batch1, tf.float16)
        half_batch2 = tf.cast(self._features_batch2, tf.float16)
        half_distances = _batch_pairwise_distances(half_batch1, half_batch2)

        def _use_half_result():
            # Keep the float16 result, widened back to float32.
            return tf.cast(half_distances, tf.float32)

        def _recompute_in_float32():
            # Fallback taken when the float16 result has a non-finite entry.
            return _batch_pairwise_distances(
                self._features_batch1, self._features_batch2
            )

        half_is_finite = tf.reduce_all(tf.math.is_finite(half_distances))
        self.distance_block = tf.cond(
            half_is_finite, _use_half_result, _recompute_in_float32
        )

        # Radius comparisons ("less than or equal" checks).
        self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
        self._radii2 = tf.placeholder(tf.float32, shape=[None, None])

        # A trailing axis is appended so the distances broadcast against
        # the radii placeholders.
        distances = tf.cast(self.distance_block, tf.float32)
        expanded = distances[..., None]

        within_radii2 = expanded <= self._radii2
        self._batch_1_in = tf.math.reduce_any(within_radii2, axis=1)

        within_radii1 = expanded <= self._radii1[:, None]
        self._batch_2_in = tf.math.reduce_any(within_radii1, axis=0)