in product_matching/hyperboloid.py [0:0]
def intersection_layer(self, x):
    """Build the pairwise box-intersection sub-graph for two sets of boxes.

    Each box is split with ``self.crop_box`` into a *center* (features
    ``[0, box_dim // 2)``) and an *offset* (features ``[box_dim // 2, end)``).
    For every index pair ``(ind1, ind2)`` with ``ind2 >= ind1`` the
    intersection box is built as follows:

    * center: the two centers are concatenated and reshaped into a length-2
      sequence ``(2, box_dim // 2)``, passed through
      ``SelfAttention(return_sequences=True)`` and a ``Dense(box_dim // 2)``
      projection, and the two resulting rows are combined with
      ``self.hyperbolic_add()``;
    * offset: the element-wise ``Minimum`` of the two offsets.

    Args:
        x: Pair ``(boxes_a, boxes_b)``. Box ``i`` of each element is selected
            with ``self.crop_box(i, ...)``, and the box count is read from
            ``shape[1]``. Presumably laid out as ``(batch, n_boxes, box_dim)``
            — TODO confirm against ``crop_box``.

    Returns:
        The per-pair ``[center, offset]`` tensors (each joined on axis 1),
        concatenated along the last axis in ``(ind1, ind2)`` order. If only
        one pair exists, that pair's tensor is returned as-is.

    Raises:
        ValueError: if no pair can be formed, i.e. either input has no boxes.
    """
    half = self._params["box_dim"] // 2
    n_a = x[0].shape[1]
    n_b = x[1].shape[1]
    # For ind1 >= n_b the inner range(ind1, n_b) is empty, so stop there.
    n_outer = min(n_a, n_b)
    if n_outer == 0:
        raise ValueError(
            "intersection_layer needs at least one box in each input; "
            f"got {n_a} and {n_b}"
        )

    # Crop every box of the second input once. They are reused for every
    # ind1 instead of rebuilding identical crop layers inside the inner loop.
    centers_b = [self.crop_box(i, 0, half)(x[1]) for i in range(n_b)]
    offsets_b = [self.crop_box(i, half, None)(x[1]) for i in range(n_b)]

    all_pairs = []
    for ind1 in tqdm(range(n_outer)):
        # Box ind1 of the first input does not depend on ind2: crop it once.
        center_a = self.crop_box(ind1, 0, half)(x[0])
        offset_a = self.crop_box(ind1, half, None)(x[0])
        # NOTE(review): only pairs with ind2 >= ind1 are formed. If x[0] and
        # x[1] hold different box sets, the (ind1 > ind2) pairs are skipped —
        # confirm this asymmetry is intended.
        for ind2 in range(ind1, n_b):
            # Stack the two centers as a length-2 sequence for attention.
            stacked = Concatenate()([center_a, centers_b[ind2]])
            stacked = Reshape((2, half))(stacked)
            center = SelfAttention(return_sequences=True)(stacked)
            center = Dense(half)(center)
            # Merge the attended center rows with hyperbolic addition.
            center = self.hyperbolic_add()(
                [self.crop_box(i, None, None)(center) for i in range(center.shape[1])]
            )
            # The intersection's offset is the element-wise minimum of the two.
            offset = Minimum()([offset_a, offsets_b[ind2]])
            all_pairs.append(Concatenate(axis=1)([center, offset]))

    if len(all_pairs) == 1:
        # Concatenating a single tensor is the identity, and some Keras
        # versions reject a one-input Concatenate, so return it directly.
        return all_pairs[0]
    return Concatenate()(all_pairs)