in src/engines.py [0:0]
def setup_model(opt):
    """Build the model named by ``opt['model']`` and move it to ``opt['device']``.

    Args:
        opt: Options mapping. Keys read:
            ``'model'`` -- one of ``'TransE'``, ``'ComplEx'``, ``'TuckER'``,
            ``'RESCAL'`` or ``'CP'``; selects the constructor.
            ``'size'``, ``'rank'``, ``'init'`` -- forwarded positionally to the
            selected constructor.
            ``'rank_r'``, ``'dropout'`` -- read only when the model is
            ``'TuckER'``.
            ``'device'`` -- passed to ``model.to``.

    Returns:
        The constructed model instance, after ``model.to(opt['device'])`` has
        been called on it.

    Raises:
        ValueError: If ``opt['model']`` is not one of the supported names.
        KeyError: If a key required by the selected model is missing from
            ``opt``.
    """
    # Factories are lambdas so that model-specific keys (e.g. TuckER's
    # 'rank_r' and 'dropout') are only looked up for the model actually chosen.
    factories = {
        'TransE': lambda: TransE(opt['size'], opt['rank'], opt['init']),
        'ComplEx': lambda: ComplEx(opt['size'], opt['rank'], opt['init']),
        'TuckER': lambda: TuckER(opt['size'], opt['rank'], opt['rank_r'],
                                 opt['init'], opt['dropout']),
        'RESCAL': lambda: RESCAL(opt['size'], opt['rank'], opt['init']),
        'CP': lambda: CP(opt['size'], opt['rank'], opt['init']),
    }

    name = opt['model']
    try:
        factory = factories[name]
    except KeyError:
        # Without this check an unknown name fell through every branch and
        # surfaced as an unhelpful UnboundLocalError on `model`.
        raise ValueError(
            f"Unknown model {name!r}; expected one of {sorted(factories)}"
        ) from None

    model = factory()
    model.to(opt['device'])
    return model