def __init__()

in hugegraph-ml/src/hugegraph_ml/models/jknet.py

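The constructor builds on PyTorch and DGL building blocks. As a minimal sketch, the import block below assumes the DGL implementations of GraphConv and JumpingKnowledge; the actual import statements in jknet.py may differ:

    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn.pytorch import GraphConv, JumpingKnowledge
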

    def __init__(self, n_in_feats, n_out_feats, n_hidden=32, n_layers=6, mode="cat", dropout=0.5):
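        """
        Args:
            n_in_feats: size of the input node features.
            n_out_feats: size of the output layer (e.g. the number of classes).
            n_hidden: hidden feature size of each GraphConv layer.
            n_layers: number of GraphConv layers stacked after the input layer.
            mode: Jumping Knowledge aggregation mode, one of "cat", "max" or "lstm".
            dropout: dropout probability.
        """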
        super(JKNet, self).__init__()
        self.mode = mode
        self.dropout = nn.Dropout(dropout)  # Dropout layer to prevent overfitting

        self.layers = nn.ModuleList()  # List to hold GraphConv layers
        # Add the first GraphConv layer (input layer)
        self.layers.append(GraphConv(n_in_feats, n_hidden, activation=F.relu))
        # Add n_layers additional hidden GraphConv layers, giving n_layers + 1 layers in total
        for _ in range(n_layers):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=F.relu))

        # Initialize the Jumping Knowledge module that aggregates the per-layer representations
        if self.mode == "lstm":
            # LSTM aggregation requires the hidden feature size and the number of layers
            self.jump = JumpingKnowledge(mode, n_hidden, n_layers)
        else:
            # Concatenation ("cat") or max pooling ("max") aggregation
            self.jump = JumpingKnowledge(mode)

        # Adjust hidden size for concatenation mode
        if self.mode == "cat":
            # Multiply by (n_layers + 1) because all layer outputs are concatenated
            n_hidden = n_hidden * (n_layers + 1)

        # Linear output layer mapping the aggregated representation to the final prediction
        self.output_layer = nn.Linear(n_hidden, n_out_feats)
        self.criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for classification
        self.reset_params()  # Reset (initialize) the model parameters
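
A hypothetical usage sketch follows; the feature sizes are chosen purely for illustration and are not taken from the repository:

    # Build a JKNet for a 7-class node-classification task with
    # 1433-dimensional input features (illustrative sizes only).
    model = JKNet(n_in_feats=1433, n_out_feats=7, n_hidden=32, n_layers=6, mode="cat")
    # In "cat" mode the output layer expects 32 * (6 + 1) = 224 input features,
    # because the JumpingKnowledge module concatenates the representations
    # produced by all 7 GraphConv layers.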