in hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py [0:0]
def __init__(self, n_in_feats, n_out_feats, n_hidden=16, n_layers=5, p_drop=0.5, pooling="sum"):
    """Build a GIN graph classifier with one linear readout head per layer.

    Creates ``n_layers - 1`` GIN convolution layers, each wrapping an ``_MLP``
    and followed by a ``BatchNorm1d(n_hidden)``. It also creates ``n_layers``
    linear prediction heads: head 0 takes ``n_in_feats`` inputs, and every
    other head takes the pooled hidden size. Finally it creates a dropout
    module and the graph-level pooling operator chosen by ``pooling``.

    :param n_in_feats: dimensionality of the input node features.
    :param n_out_feats: output size of every prediction head. The criterion is
        ``CrossEntropyLoss``, so these are class logits.
    :param n_hidden: hidden dimensionality of the GIN layers.
    :param n_layers: number of prediction heads. ``n_layers - 1`` GIN layers
        are built. Must be at least 2.
    :param p_drop: dropout probability used by ``self.drop``.
    :param pooling: one of ``"sum"``, ``"mean"``, ``"max"``,
        ``"global_attention"`` or ``"set2set"``.
    :raises ValueError: if ``n_layers < 2`` or ``pooling`` is not supported.
    """
    super().__init__()
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently skip this input check.
    if n_layers < 2:
        raise ValueError("The number of GIN layers must be at least 2.")

    self.gin_layers = nn.ModuleList()
    self.batch_norms = nn.ModuleList()
    self.criterion = nn.CrossEntropyLoss()

    # (n_layers - 1) GIN layers, each with an _MLP aggregator and batch norm.
    # Only the first layer consumes the raw input features.
    for layer in range(n_layers - 1):
        mlp_in_dim = n_in_feats if layer == 0 else n_hidden
        mlp = _MLP(mlp_in_dim, n_hidden, n_hidden)
        # learn_eps=False keeps epsilon fixed; set to True to learn it.
        self.gin_layers.append(GINConv(mlp, learn_eps=False))
        self.batch_norms.append(nn.BatchNorm1d(n_hidden))

    # One linear readout per layer output. Head 0 is sized for the input
    # features. The remaining heads are sized for the pooled hidden
    # representation, which set2set doubles.
    pooled_hidden_dim = 2 * n_hidden if pooling == "set2set" else n_hidden
    # NOTE(review): head 0 expects an n_in_feats-dim pooled vector, while the
    # "set2set" / "global_attention" poolers below are sized for n_hidden
    # inputs. This only lines up if forward() does not pool the raw input
    # with those poolers, or if n_in_feats == n_hidden. TODO confirm.
    self.linear_prediction = nn.ModuleList([nn.Linear(n_in_feats, n_out_feats)])
    for _ in range(n_layers - 1):
        self.linear_prediction.append(nn.Linear(pooled_hidden_dim, n_out_feats))

    self.drop = nn.Dropout(p_drop)

    # Graph-level readout operator.
    if pooling == "sum":
        self.pool = SumPooling()
    elif pooling == "mean":
        self.pool = AvgPooling()
    elif pooling == "max":
        self.pool = MaxPooling()
    elif pooling == "global_attention":
        gate_nn = nn.Linear(n_hidden, 1)
        self.pool = GlobalAttentionPooling(gate_nn)
    elif pooling == "set2set":
        self.pool = Set2Set(n_hidden, n_iters=2, n_layers=1)
    else:
        raise ValueError(f"Unsupported pooling type: {pooling}")