in python/dglke/models/general_models.py [0:0]
def __init__(self, args, model_name, n_entities, n_relations, hidden_dim, gamma,
             double_entity_emb=False, double_relation_emb=False):
    """Build the knowledge-embedding model: embedding tables, loss and score function.

    Parameters
    ----------
    args :
        Option namespace. Read directly here: ``has_edge_importance``,
        ``mix_cpu_gpu``, ``strict_rel_part`` and ``soft_rel_part``. The loss
        options ``loss_genre`` (default ``'Logsigmoid'``),
        ``neg_adversarial_sampling`` (default ``False``),
        ``adversarial_temperature`` (default ``1.0``) and ``pairwise``
        (default ``False``) are optional. ``args`` is also forwarded to
        ``get_device``, ``LossGenerator`` and every ``ExternalEmbedding``.
    model_name : str
        One of 'TransE', 'TransE_l1', 'TransE_l2', 'TransR', 'DistMult',
        'ComplEx', 'RESCAL', 'RotatE', 'SimplE'. 'TransE' is an alias of
        'TransE_l2'.
    n_entities : int
        Number of rows of the entity embedding table.
    n_relations : int
        Number of rows of the relation embedding table (and of the TransR
        projection table).
    hidden_dim : int
        Base embedding width.
    gamma : float
        Passed to the TransE/TransR/RotatE score functions; also sets
        ``self.emb_init = (gamma + EMB_INIT_EPS) / hidden_dim``.
    double_entity_emb : bool
        If True, entity embeddings are ``2 * hidden_dim`` wide.
    double_relation_emb : bool
        If True, relation embeddings are ``2 * hidden_dim`` wide.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported models. This is checked
        before any embedding table is allocated.
    """
    super(KEModel, self).__init__()

    supported_models = ('TransE', 'TransE_l1', 'TransE_l2', 'TransR', 'DistMult',
                        'ComplEx', 'RESCAL', 'RotatE', 'SimplE')
    # Fail fast: an unknown name previously surfaced only as an AttributeError
    # on ``self.score_func``, after the embedding tables had been allocated.
    if model_name not in supported_models:
        raise ValueError('Unsupported model_name {!r}; expected one of: {}'.format(
            model_name, ', '.join(supported_models)))

    self.args = args
    self.has_edge_importance = args.has_edge_importance
    self.n_entities = n_entities
    self.n_relations = n_relations
    self.model_name = model_name
    self.hidden_dim = hidden_dim
    self.eps = EMB_INIT_EPS
    # Passed to RotatEScore below; presumably also used by reset_parameters
    # (not visible here) -- confirm.
    self.emb_init = (gamma + self.eps) / hidden_dim
    entity_dim = 2 * hidden_dim if double_entity_emb else hidden_dim
    relation_dim = 2 * hidden_dim if double_relation_emb else hidden_dim

    device = get_device(args)
    # With mix_cpu_gpu the embedding tables are placed on CPU instead of on
    # the device returned by get_device.
    emb_device = F.cpu() if args.mix_cpu_gpu else device

    # The loss options are optional attributes of ``args``.
    self.loss_gen = LossGenerator(args,
                                  getattr(args, 'loss_genre', 'Logsigmoid'),
                                  getattr(args, 'neg_adversarial_sampling', False),
                                  getattr(args, 'adversarial_temperature', 1.0),
                                  getattr(args, 'pairwise', False))
    self.entity_emb = ExternalEmbedding(args, n_entities, entity_dim, emb_device)

    # For RESCAL, relation_emb = relation_dim * entity_dim
    rel_dim = relation_dim * entity_dim if model_name == 'RESCAL' else relation_dim
    self.rel_dim = rel_dim
    self.entity_dim = entity_dim
    self.strict_rel_part = args.strict_rel_part
    self.soft_rel_part = args.soft_rel_part
    if self.strict_rel_part or self.soft_rel_part:
        # Relation partitioning: only a CPU-resident global_relation_emb is
        # created here; no relation_emb attribute is set by this constructor.
        self.global_relation_emb = ExternalEmbedding(args, n_relations, rel_dim, F.cpu())
    else:
        self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim, emb_device)

    if model_name in ('TransE', 'TransE_l2'):
        self.score_func = TransEScore(gamma, 'l2')
    elif model_name == 'TransE_l1':
        self.score_func = TransEScore(gamma, 'l1')
    elif model_name == 'TransR':
        # TransR additionally owns a projection table with n_relations rows of
        # entity_dim * relation_dim entries each.
        projection_emb = ExternalEmbedding(args,
                                           n_relations,
                                           entity_dim * relation_dim,
                                           emb_device)
        self.score_func = TransRScore(gamma, projection_emb, relation_dim, entity_dim)
    elif model_name == 'DistMult':
        self.score_func = DistMultScore()
    elif model_name == 'ComplEx':
        self.score_func = ComplExScore()
    elif model_name == 'RESCAL':
        self.score_func = RESCALScore(relation_dim, entity_dim)
    elif model_name == 'RotatE':
        self.score_func = RotatEScore(gamma, self.emb_init)
    else:
        # 'SimplE' is the only name left after the validation above.
        self.score_func = SimplEScore()

    # Negative-scoring helpers; the True/False flag selects the head vs tail
    # variant, per the attribute names they are stored under.
    self.head_neg_score = self.score_func.create_neg(True)
    self.tail_neg_score = self.score_func.create_neg(False)
    self.head_neg_prepare = self.score_func.create_neg_prepare(True)
    self.tail_neg_prepare = self.score_func.create_neg_prepare(False)
    self.reset_parameters()