def __init__()

in hugegraph-ml/src/hugegraph_ml/models/gin_global_pool.py [0:0]


    def __init__(self, n_in_feats, n_out_feats, n_hidden=16, n_layers=5, p_drop=0.5, pooling="sum"):
        """Build the GIN convolutions, per-layer prediction heads and readout.

        Args:
            n_in_feats: Input size of the first MLP and of the first
                prediction head.
            n_out_feats: Output size of every prediction head.
            n_hidden: Hidden size used by the MLPs, the batch norms and the
                pooling modules.
            n_layers: Number of prediction heads; ``n_layers - 1`` GIN
                convolutions are created. Must be at least 2.
            p_drop: Probability handed to ``nn.Dropout``.
            pooling: Graph readout to use. One of ``"sum"``, ``"mean"``,
                ``"max"``, ``"global_attention"`` or ``"set2set"``.

        Raises:
            ValueError: If ``n_layers`` is smaller than 2 or ``pooling`` is
                not one of the supported names.
        """
        super().__init__()

        # Validate up front with real exceptions: an `assert` is removed when
        # Python runs with -O, and failing early avoids allocating layers for
        # a configuration that is rejected anyway.
        if n_layers < 2:
            raise ValueError("The number of GIN layers must be at least 2.")
        # A tuple (not a set) so that unhashable values still end in ValueError.
        supported_pooling = ("sum", "mean", "max", "global_attention", "set2set")
        if pooling not in supported_pooling:
            raise ValueError(f"Unsupported pooling type: {pooling}")

        self.gin_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.criterion = nn.CrossEntropyLoss()

        # (n_layers - 1) GIN convolutions, each wrapping an _MLP aggregator and
        # paired with a batch norm; only the first one consumes the raw input.
        for layer in range(n_layers - 1):
            mlp_in = n_in_feats if layer == 0 else n_hidden
            mlp = _MLP(mlp_in, n_hidden, n_hidden)
            # learn_eps=False keeps epsilon fixed; set to True to learn it.
            self.gin_layers.append(GINConv(mlp, learn_eps=False))
            self.batch_norms.append(nn.BatchNorm1d(n_hidden))

        # One linear prediction head per layer output. Heads after the first
        # take 2 * n_hidden inputs when set2set pooling is selected.
        pooled_dim = 2 * n_hidden if pooling == "set2set" else n_hidden
        self.linear_prediction = nn.ModuleList()
        for layer in range(n_layers):
            head_in = n_in_feats if layer == 0 else pooled_dim
            self.linear_prediction.append(nn.Linear(head_in, n_out_feats))

        self.drop = nn.Dropout(p_drop)

        # Graph-level readout; `pooling` was validated above, so exactly one
        # of these branches runs.
        if pooling == "sum":
            self.pool = SumPooling()
        elif pooling == "mean":
            self.pool = AvgPooling()
        elif pooling == "max":
            self.pool = MaxPooling()
        elif pooling == "global_attention":
            gate_nn = nn.Linear(n_hidden, 1)
            self.pool = GlobalAttentionPooling(gate_nn)
        elif pooling == "set2set":
            self.pool = Set2Set(n_hidden, n_iters=2, n_layers=1)