in hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py [0:0]
def correct(self, g, y_soft, y_true, mask):
    """Correct base predictions by propagating residual errors over ``g``.

    The residual ``y_true - y_soft`` is computed on the nodes selected by
    ``mask`` (zero elsewhere), propagated with ``self.prop1`` and added back
    onto ``y_soft``. With ``self.autoscale`` the propagated error is clamped
    to [-1, 1] and rescaled per row; otherwise it is multiplied by the fixed
    ``self.scale``.

    Args:
        g: Graph handed to ``self.prop1``; its ``local_scope()`` is entered
            for the duration of the call.
        y_soft: Base predictions, one row per node. Rows must sum to 1 on
            average (within 1e-2), e.g. softmax output.
        y_true: Targets for the selected nodes. A ``torch.long`` tensor is
            treated as class indices and one-hot encoded to
            ``y_soft.size(-1)`` classes; otherwise it is used as-is.
        mask: Boolean mask or index tensor selecting the labelled nodes.

    Returns:
        Corrected predictions with the same shape as ``y_soft``. Any entry
        that becomes NaN falls back to the matching ``y_soft`` entry.

    Raises:
        ValueError: If the rows of ``y_soft`` do not sum to ~1 on average, or
            ``y_true`` does not have exactly one row per selected node.
    """
    with g.local_scope():
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, and a single-row ``y_true`` would then silently
        # broadcast over every masked node below.
        num_rows = y_soft.size(0)
        if num_rows > 0:  # avoid ZeroDivisionError on empty input
            mean_row_sum = float(y_soft.sum()) / num_rows
            # ``not (... < tol)`` rather than ``>= tol`` so NaN is rejected too.
            if not abs(mean_row_sum - 1.0) < 1e-2:
                raise ValueError(
                    "y_soft rows must sum to 1 on average, got mean row sum "
                    f"{mean_row_sum}"
                )
        numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
        if y_true.size(0) != numel:
            raise ValueError(
                f"y_true has {y_true.size(0)} rows but mask selects "
                f"{numel} nodes"
            )

        if y_true.dtype == torch.long:
            y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
                y_soft.dtype
            )

        # Residual is non-zero only on the labelled nodes.
        error = torch.zeros_like(y_soft)
        error[mask] = y_true - y_soft[mask]

        if self.autoscale:
            # post_step is passed to prop1 (presumably applied after each
            # propagation step) to keep the propagated error within [-1, 1].
            smoothed_error = self.prop1(
                g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)
            )
            # Rescale each row so its total |error| matches the average
            # |error| observed on the labelled nodes.
            sigma = error[mask].abs().sum() / numel
            scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)
            # Rows with zero smoothed error (and sigma > 0) give inf; also cap
            # implausibly large factors. 0/0 yields NaN, handled below.
            scale[scale.isinf() | (scale > 1000)] = 1.0
        else:

            def fix_input(x):
                # Re-impose the known residuals on the labelled nodes.
                x[mask] = error[mask]
                return x

            smoothed_error = self.prop1(g, error, post_step=fix_input)
            scale = self.scale

        result = y_soft + scale * smoothed_error
        # NaN entries (e.g. from a 0/0 scale) fall back to the base prediction.
        nan_mask = result.isnan()
        result[nan_mask] = y_soft[nan_mask]
        return result