in src/utils.py [0:0]
def normalize_embeddings(emb, types, mean=None):
    """
    Normalize embeddings in place by recentering them and/or rescaling rows.

    Args:
        emb: tensor modified in place. The mean is taken over dim 0 and the
            L2 norm over dim 1 (presumably one embedding vector per row --
            confirm against callers).
        types: comma-separated list of steps applied in order. Supported
            steps are 'center' (subtract the mean over dim 0) and 'renorm'
            (divide each row by its L2 norm). Empty entries and whitespace
            around entries are ignored, so '' performs no normalization.
        mean: optional precomputed mean to subtract for 'center'. When it is
            None, the mean is computed from `emb` the first time 'center'
            runs and reused for any later 'center' step.

    Returns:
        The mean used for centering, moved to CPU, or None when no mean was
        supplied and no 'center' step ran.

    Raises:
        ValueError: if `types` contains an unknown step. The whole list is
            validated before `emb` is touched, so a bad list never leaves
            `emb` partially normalized.

    Note:
        'renorm' divides by the row norm without any guard, so a row whose
        norm is zero ends up as NaN for floating-point tensors.
    """
    # Parse and validate up front: the steps below mutate `emb` in place, so
    # failing halfway through the list would leave it half-normalized.
    steps = [step.strip() for step in types.split(',')]
    steps = [step for step in steps if step]
    for step in steps:
        if step not in ('center', 'renorm'):
            raise ValueError('Unknown normalization type: "%s"' % step)

    for step in steps:
        if step == 'center':
            if mean is None:
                mean = emb.mean(0, keepdim=True)
            emb.sub_(mean.expand_as(emb))
        else:  # 'renorm' -- the only other value that passes validation
            emb.div_(emb.norm(2, 1, keepdim=True).expand_as(emb))
    return mean.cpu() if mean is not None else None