in python/dglke/models/general_models.py [0:0]
def __init__(self, device, model_name, hidden_dim,
             double_entity_emb=False, double_relation_emb=False,
             gamma=0., batch_size=DEFAULT_INFER_BATCHSIZE):
    """Build an inference-only knowledge-graph embedding model.

    Creates empty entity/relation ``InferEmbedding`` tables on ``device``
    and selects the score function matching ``model_name``.

    Parameters
    ----------
    device :
        Stored as ``self.device`` and passed to both ``InferEmbedding``
        tables.
    model_name : str
        Selects the score function. Recognized values: ``'TransE'`` /
        ``'TransE_l2'`` (L2 TransE), ``'TransE_l1'``, ``'DistMult'``,
        ``'ComplEx'``, ``'RESCAL'``, ``'RotatE'`` and ``'SimplE'``.
        ``'TransR'`` is recognized but not supported for inference.
    hidden_dim : int
        Base embedding size. It is doubled for entities / relations when
        the matching ``double_*_emb`` flag is set (only the RESCAL score
        consumes those dims here) and is used to derive RotatE's
        ``emb_init``.
    double_entity_emb : bool
        Use ``2 * hidden_dim`` as the entity dimension.
    double_relation_emb : bool
        Use ``2 * hidden_dim`` as the relation dimension.
    gamma : float
        Passed to ``TransEScore`` and ``RotatEScore`` and used in RotatE's
        ``emb_init``; presumably the scoring margin (confirm against those
        classes).
    batch_size : int
        Stored as ``self.batch_size``; presumably the inference batch size.

    Raises
    ------
    NotImplementedError
        If ``model_name`` is ``'TransR'``.
    ValueError
        If ``model_name`` is not one of the recognized names.
    """
    super(InferModel, self).__init__()
    self.device = device
    self.model_name = model_name
    entity_dim = 2 * hidden_dim if double_entity_emb else hidden_dim
    relation_dim = 2 * hidden_dim if double_relation_emb else hidden_dim
    self.entity_emb = InferEmbedding(device)
    self.relation_emb = InferEmbedding(device)
    self.batch_size = batch_size

    if model_name in ('TransE', 'TransE_l2'):
        self.score_func = TransEScore(gamma, 'l2')
    elif model_name == 'TransE_l1':
        self.score_func = TransEScore(gamma, 'l1')
    elif model_name == 'TransR':
        # Explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would leave `score_func` unset silently.
        raise NotImplementedError(
            'Do not support inference of TransR model now.')
    elif model_name == 'DistMult':
        self.score_func = DistMultScore()
    elif model_name == 'ComplEx':
        self.score_func = ComplExScore()
    elif model_name == 'RESCAL':
        self.score_func = RESCALScore(relation_dim, entity_dim)
    elif model_name == 'RotatE':
        emb_init = (gamma + EMB_INIT_EPS) / hidden_dim
        self.score_func = RotatEScore(gamma, emb_init)
    elif model_name == 'SimplE':
        self.score_func = SimplEScore()
    else:
        # Fail fast: otherwise the object is built without `score_func` and
        # the error only surfaces later as an unrelated AttributeError.
        raise ValueError(
            'Unknown model name for inference: {}'.format(model_name))