def __init__()

in python/dglke/models/general_models.py [0:0]


    def __init__(self, device, model_name, hidden_dim,
        double_entity_emb=False, double_relation_emb=False,
        gamma=0., batch_size=DEFAULT_INFER_BATCHSIZE):
        """Build an inference-only model and select its score function.

        Parameters
        ----------
        device
            Forwarded to the entity and relation ``InferEmbedding`` objects.
        model_name : str
            Name of the scoring model. Handled here: 'TransE', 'TransE_l2',
            'TransE_l1', 'DistMult', 'ComplEx', 'RESCAL', 'RotatE' and
            'SimplE'. 'TransR' is recognised but explicitly rejected.
        hidden_dim
            Base embedding dimension. Used to derive the entity/relation
            dimensions passed to ``RESCALScore`` and, for RotatE, ``emb_init``.
        double_entity_emb : bool
            If True, the entity dimension is ``2 * hidden_dim``.
        double_relation_emb : bool
            If True, the relation dimension is ``2 * hidden_dim``.
        gamma : float
            Passed to ``TransEScore`` and ``RotatEScore``; RotatE also uses
            it to compute ``emb_init``.
        batch_size
            Stored as ``self.batch_size``.

        Raises
        ------
        AssertionError
            If ``model_name`` is 'TransR' (inference not supported).
        ValueError
            If ``model_name`` is not one of the names listed above.
        """
        super(InferModel, self).__init__()

        self.device = device
        self.model_name = model_name
        # Only RESCALScore consumes these two dimensions in this constructor.
        entity_dim = 2 * hidden_dim if double_entity_emb else hidden_dim
        relation_dim = 2 * hidden_dim if double_relation_emb else hidden_dim

        self.entity_emb = InferEmbedding(device)
        self.relation_emb = InferEmbedding(device)
        self.batch_size = batch_size

        if model_name in ('TransE', 'TransE_l2'):
            # Plain 'TransE' defaults to the L2 distance variant.
            self.score_func = TransEScore(gamma, 'l2')
        elif model_name == 'TransE_l1':
            self.score_func = TransEScore(gamma, 'l1')
        elif model_name == 'TransR':
            # Raise explicitly instead of `assert False`: assert statements
            # are removed under `python -O`, which would otherwise leave the
            # object without a score_func. Same exception type and message
            # as before, so existing callers are unaffected.
            raise AssertionError('Do not support inference of TransR model now.')
        elif model_name == 'DistMult':
            self.score_func = DistMultScore()
        elif model_name == 'ComplEx':
            self.score_func = ComplExScore()
        elif model_name == 'RESCAL':
            self.score_func = RESCALScore(relation_dim, entity_dim)
        elif model_name == 'RotatE':
            emb_init = (gamma + EMB_INIT_EPS) / hidden_dim
            self.score_func = RotatEScore(gamma, emb_init)
        elif model_name == 'SimplE':
            self.score_func = SimplEScore()
        else:
            # Fail fast rather than leaving score_func unset, which would
            # only surface later as an AttributeError at scoring time.
            raise ValueError(
                'Unsupported model name for inference: {}'.format(model_name))