in hype_kg/codes/model.py [0:0]
def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
             writer=None, geo=None,
             cen=None, offset_deepsets=None,
             center_deepsets=None, offset_use_center=None, center_use_offset=None,
             att_reg=0., off_reg=0., att_tem=1., euo=False,
             gamma2=0, bn='no', nat=1, activation='relu', manifold='poincare',
             curvature=1.0, trainable_curvature=False,
             use_semantics=False):
    """Build the embedding tables and set-aggregation modules of the model.

    Args:
        model_name: Must be 'TransE' or 'BoxTransE'; anything else raises
            ValueError before any tensors are allocated.
        nentity: Number of rows in the entity embedding table (and in the
            per-entity offset table when geo == 'box' and euo is true).
        nrelation: Number of rows in the relation embedding table (and in
            the relation offset table when geo == 'box').
        hidden_dim: Width of every embedding table; also divides the
            uniform-initialisation range.
        gamma: Stored as the frozen parameter ``self.gamma``; also sets
            ``embedding_range = (gamma + epsilon) / hidden_dim`` with
            ``epsilon = 2.0``. (Presumably the scoring margin -- confirm
            against the scoring code.)
        writer: Stored as ``self.writer`` only (presumably a summary writer).
        geo: 'vec' builds ``self.deepsets``; 'box' builds offset embeddings
            plus ``self.center_sets`` and ``self.offset_sets``. Any other
            value builds neither.
        cen: Stored as ``self.cen`` only.
        offset_deepsets: For geo == 'box': 'vanilla', 'inductive' or 'min'.
        center_deepsets: For geo in ('vec', 'box'): 'vanilla', 'attention',
            'eleattention' or 'mean'.
        offset_use_center: Forwarded to OffsetSet / InductiveOffsetSet.
        center_use_offset: Forwarded to the box center aggregator (the
            'vec' aggregator always receives False).
        att_reg, att_tem: Forwarded to AttentionSet.
        off_reg: Forwarded to InductiveOffsetSet.
        euo: If true and geo == 'box', also builds
            ``self.entity_offset_embedding``.
        gamma2: Stored as the frozen parameter ``self.gamma2``; 0 means
            "reuse gamma".
        bn, nat: Forwarded to CenterSet / AttentionSet.
        activation: 'none' (Identity), 'relu' or 'softplus'; stored as
            ``self.func``.
        manifold: 'poincare', 'lorentz', or any other value for Euclidean.
        curvature: Initial value of the curvature parameter ``self.c``.
        trainable_curvature: Whether ``self.c`` receives gradients.
        use_semantics: If true, entity embeddings are initialised from the
            pickled {entity id: vector} dict in ``get_vectors.pkl`` (path
            relative to the working directory) instead of uniformly.

    Raises:
        ValueError: On an unsupported model_name, activation,
            center_deepsets or offset_deepsets value, or when the semantic
            vector file is empty or its vectors are not hidden_dim long.
    """
    # Fail fast, before allocating any tensors or reading files.
    if model_name not in ['TransE', 'BoxTransE']:
        raise ValueError('model %s not supported' % model_name)

    super(Query2Manifold, self).__init__()
    self.model_name = model_name
    self.nentity = nentity
    self.nrelation = nrelation
    self.hidden_dim = hidden_dim
    self.epsilon = 2.0
    self.writer = writer
    self.geo = geo
    self.cen = cen
    self.offset_deepsets = offset_deepsets
    self.center_deepsets = center_deepsets
    self.offset_use_center = offset_use_center
    self.center_use_offset = center_use_offset
    self.att_reg = att_reg
    self.off_reg = off_reg
    self.att_tem = att_tem
    self.euo = euo
    self.his_step = 0
    self.bn = bn
    self.nat = nat

    if activation == 'none':
        self.func = Identity
    elif activation == 'relu':
        self.func = F.relu
    elif activation == 'softplus':
        self.func = F.softplus
    else:
        # Otherwise self.func would be silently missing and fail much later.
        raise ValueError('activation %s not supported' % activation)

    # Curvature is always registered as a Parameter; it only receives
    # gradients when trainable_curvature is set.
    self.c = nn.Parameter(torch.FloatTensor([curvature]),
                          requires_grad=bool(trainable_curvature))
    if manifold == 'poincare':
        self.manifold = poincare.PoincareBall(self.c)
    elif manifold == 'lorentz':
        self.manifold = lorentz.Lorentz(self.c)
    else:
        self.manifold = euclidean.Euclidean()

    self.gamma = ManifoldParameter(
        torch.Tensor([gamma]), requires_grad=False, manifold=self.manifold)
    if gamma2 == 0:
        gamma2 = gamma
    self.gamma2 = ManifoldParameter(
        torch.Tensor([gamma2]), requires_grad=False, manifold=self.manifold)
    self.embedding_range = ManifoldParameter(
        torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
        requires_grad=False, manifold=self.manifold)
    self.entity_dim = hidden_dim
    self.relation_dim = hidden_dim
    init_bound = self.embedding_range.item()

    def _uniform_table(rows, dim, low):
        """Return a (rows, dim) ManifoldParameter filled from U(low, init_bound)."""
        table = ManifoldParameter(torch.zeros(rows, dim), manifold=self.manifold)
        nn.init.uniform_(tensor=table, a=low, b=init_bound)
        return table

    def _make_center_set(use_offset):
        """Build the center aggregator selected by self.center_deepsets."""
        kind = self.center_deepsets
        if kind == 'vanilla':
            return CenterSet(self.manifold, self.relation_dim, self.relation_dim,
                             use_offset, agg_func=torch.mean, bn=bn, nat=nat)
        if kind == 'attention':
            return AttentionSet(self.manifold, self.relation_dim, self.relation_dim,
                                use_offset, att_reg=self.att_reg,
                                att_tem=self.att_tem, bn=bn, nat=nat)
        if kind == 'eleattention':
            return AttentionSet(self.manifold, self.relation_dim, self.relation_dim,
                                use_offset, att_reg=self.att_reg, att_type='ele',
                                att_tem=self.att_tem, bn=bn, nat=nat)
        if kind == 'mean':
            return MeanSet(self.manifold)
        raise ValueError('center_deepsets %s not supported' % kind)

    if use_semantics:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at a trusted file.
        with open("get_vectors.pkl", "rb") as handle:
            # Key: entity id, value: semantic vector.
            semantic_vectors = pickle.load(handle)
        if not semantic_vectors:
            raise ValueError('get_vectors.pkl contains no semantic vectors')
        # Semantic vector width must match the width of the box centers.
        vector_dim = len(next(iter(semantic_vectors.values())))
        if vector_dim != hidden_dim:
            raise ValueError('semantic vector dim %d does not match hidden_dim %d'
                             % (vector_dim, hidden_dim))
        self.entity_embedding = ManifoldParameter(
            torch.stack([torch.Tensor(semantic_vectors[entity_id])
                         for entity_id in range(nentity)]),
            manifold=self.manifold)
    else:
        self.entity_embedding = _uniform_table(nentity, self.entity_dim, -init_bound)
    self.relation_embedding = _uniform_table(nrelation, self.relation_dim, -init_bound)

    if self.geo == 'vec':
        self.deepsets = _make_center_set(False)
    elif self.geo == 'box':
        # Offsets are initialised non-negative: U(0, embedding_range).
        self.offset_embedding = _uniform_table(nrelation, self.entity_dim, 0.)
        if self.euo:
            self.entity_offset_embedding = _uniform_table(nentity, self.entity_dim, 0.)
        self.center_sets = _make_center_set(self.center_use_offset)
        if self.offset_deepsets == 'vanilla':
            self.offset_sets = OffsetSet(self.manifold, self.relation_dim, self.relation_dim,
                                         self.offset_use_center, agg_func=torch.mean)
        elif self.offset_deepsets == 'inductive':
            self.offset_sets = InductiveOffsetSet(self.manifold, self.relation_dim,
                                                  self.relation_dim, self.offset_use_center,
                                                  self.off_reg, agg_func=torch.mean)
        elif self.offset_deepsets == 'min':
            self.offset_sets = MinSet(self.manifold)
        else:
            raise ValueError('offset_deepsets %s not supported' % self.offset_deepsets)