in hype_kg/codes/model.py [0:0]
def train_step(model, optimizer, train_iterator, args, step):
    """Run a single optimization step on one batch from ``train_iterator``.

    Draws ``(positive_sample, negative_sample, subsampling_weight, mode)``
    from the iterator, scores both sample sets with ``model``, forms the
    averaged positive/negative log-sigmoid loss (optionally weighted by the
    sub-sampling weights), adds optional norm regularization, back-propagates,
    and advances ``optimizer``.

    Args:
        model: scoring model; must expose ``geo``, ``manifold`` and be
            callable as ``model(sample, rel_len, qtype, mode=...)`` returning
            a 6-tuple whose items 0 and 3 are the plain and center-plus scores.
        optimizer: optimizer whose ``step`` accepts a ``manifold`` keyword.
        train_iterator: batch iterator carrying a ``qtype`` attribute such as
            ``"1-chain"`` (the leading integer is the relation-path length).
        args: namespace with ``cuda``, ``uni_weight``, ``regularization``.
        step: global step counter (unused here; kept for interface parity).

    Returns:
        dict with ``positive_sample_loss``, ``negative_sample_loss``,
        ``loss`` and, when regularization is active, ``regularization``.
    """
    model.train()
    optimizer.zero_grad()

    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
    if args.cuda:
        positive_sample = positive_sample.cuda()
        negative_sample = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()

    qtype = train_iterator.qtype
    # qtype strings are of the form "<n>-..."; n is the relation-path length.
    rel_len = int(qtype.split('-')[0])

    # Box geometry scores with the "center plus" term; other geometries use
    # the plain score. Decide once and reuse for both sample sets.
    use_box = model.geo == 'box'

    # Negative samples first (keeps the original forward-pass order, which
    # matters if the model consumes RNG state, e.g. dropout, in train mode).
    neg_out = model((positive_sample, negative_sample), rel_len, qtype, mode=mode)
    neg_logits = neg_out[3] if use_box else neg_out[0]
    negative_score = F.logsigmoid(-neg_logits).mean(dim=1)

    pos_out = model(positive_sample, rel_len, qtype)
    pos_logits = pos_out[3] if use_box else pos_out[0]
    positive_score = F.logsigmoid(pos_logits).squeeze(dim=1)

    if args.uni_weight:
        # Uniform weighting: plain mean over the batch.
        positive_sample_loss = -positive_score.mean()
        negative_sample_loss = -negative_score.mean()
    else:
        # Word2vec-style sub-sampling weights, normalized by their sum.
        weight_total = subsampling_weight.sum()
        positive_sample_loss = -(subsampling_weight * positive_score).sum() / weight_total
        negative_sample_loss = -(subsampling_weight * negative_score).sum() / weight_total

    loss = (positive_sample_loss + negative_sample_loss) / 2

    if args.regularization != 0.0:
        # NOTE(review): the relation term applies .norm(p=3) twice — this
        # matches the upstream code line-for-line, but looks suspicious;
        # confirm against the reference implementation before changing.
        regularization = args.regularization * (
            model.entity_embedding.norm(p=3) ** 3 +
            model.relation_embedding.norm(p=3).norm(p=3) ** 3
        )
        loss = loss + regularization
        regularization_log = {'regularization': regularization.item()}
    else:
        regularization_log = {}

    loss.backward()
    # Riemannian-style optimizer: the step is taken on the model's manifold.
    optimizer.step(manifold=model.manifold)

    return {
        **regularization_log,
        'positive_sample_loss': positive_sample_loss.item(),
        'negative_sample_loss': negative_sample_loss.item(),
        'loss': loss.item()
    }