in src/models.py [0:0]
def get_candidates(self, chunk_begin=None, chunk_size=None, target='rhs', indices=None):
    """Return candidate embedding rows for scoring.

    Rows for ``target`` in ``('rhs', 'lhs')`` are read from
    ``self.embeddings[0].weight.data``. Rows for ``'rel'`` are read from
    ``self.embeddings[1].weight.data``.

    Args:
        chunk_begin: First row of the contiguous chunk to return. ``None``
            means start at row 0.
        chunk_size: Number of rows in the chunk. ``None`` means every row
            from ``chunk_begin`` to the end.
        target: ``'rhs'``, ``'lhs'`` or ``'rel'``.
        indices: Optional 2-D index array ``(bsz, num_cands)``. It is only
            used for ``'rhs'``/``'lhs'``: the first ``num_cands // 2`` columns
            select lhs candidates and the remaining columns select rhs
            candidates. When given, ``chunk_begin``/``chunk_size`` are ignored.

    Returns:
        Without ``indices``: the selected weight rows, transposed with
        ``transpose(0, 1)``. This is ``(dim, n_rows)`` assuming the weight is
        ``(num_rows, dim)``; TODO confirm for every model.
        With ``indices``: a tensor ``(bsz, k, -1)``, where ``k`` is the number
        of index columns selected for ``target``.
        ``None`` for any other ``target``.
    """
    def _chunk(weight):
        # Slice a contiguous run of rows. Missing bounds mean "from the start"
        # and "to the end", so the documented None defaults actually work.
        start = 0 if chunk_begin is None else chunk_begin
        stop = None if chunk_size is None else start + chunk_size
        return weight[start:stop].transpose(0, 1)

    if target in ('rhs', 'lhs'):  # TODO: extend to other models
        weight = self.embeddings[0].weight.data
        if indices is None:
            return _chunk(weight)
        bsz = indices.shape[0]
        half = indices.shape[1] // 2
        selected = indices[:, half:] if target == 'rhs' else indices[:, :half]
        # Use the real column count of the slice. For an odd number of
        # candidates the rhs half has one more column than num_cands // 2.
        num_selected = selected.shape[1]
        return weight[selected.reshape(-1)].reshape(bsz, num_selected, -1)
    if target == 'rel':
        return _chunk(self.embeddings[1].weight.data)
    # Unknown targets yield None, as before.
    return None