in hugegraph-ml/src/hugegraph_ml/models/jknet.py [0:0]
def __init__(self, n_in_feats, n_out_feats, n_hidden=32, n_layers=6, mode="cat", dropout=0.5):
    """Build a Jumping Knowledge Network (JKNet).

    The encoder is a stack of ``n_layers + 1`` GraphConv layers: one input
    layer plus ``n_layers`` hidden layers, each using ReLU activation. Their
    representations are combined by a ``JumpingKnowledge`` module, and the
    result feeds a linear output layer. The training criterion is
    ``nn.CrossEntropyLoss``.

    Args:
        n_in_feats: Size of each input node feature vector.
        n_out_feats: Size of the output layer (e.g. the number of classes).
        n_hidden: Hidden size of every GraphConv layer.
        n_layers: Number of GraphConv layers added after the input layer.
            The model holds ``n_layers + 1`` GraphConv layers in total.
        mode: Aggregation used by JumpingKnowledge: ``"cat"``
            (concatenation), ``"max"`` (max pooling) or ``"lstm"``.
        dropout: Drop probability passed to ``nn.Dropout``.

    Raises:
        ValueError: If ``mode`` is not one of ``"cat"``, ``"max"``, ``"lstm"``.
    """
    super(JKNet, self).__init__()
    # Fail fast with a clear message rather than relying on the aggregation
    # module's internal checks for an unsupported mode.
    supported_modes = ("cat", "max", "lstm")
    if mode not in supported_modes:
        raise ValueError(f"Unsupported JKNet mode {mode!r}; expected one of {supported_modes}")
    self.mode = mode
    self.dropout = nn.Dropout(dropout)  # Dropout layer to reduce overfitting
    self.layers = nn.ModuleList()  # Holds the GraphConv layers in order
    # Input layer: projects raw node features to the hidden size.
    self.layers.append(GraphConv(n_in_feats, n_hidden, activation=F.relu))
    # Hidden layers: n_layers additional hidden->hidden GraphConv layers.
    for _ in range(n_layers):
        self.layers.append(GraphConv(n_hidden, n_hidden, activation=F.relu))
    # Jumping Knowledge module that aggregates the per-layer representations.
    if self.mode == "lstm":
        # NOTE(review): JumpingKnowledge's ``num_layers`` is documented as the
        # number of layers aggregated, which is n_layers + 1 here. It is kept
        # as n_layers to preserve existing parameter shapes. Confirm intent.
        self.jump = JumpingKnowledge(mode, n_hidden, n_layers)
    else:
        # Concatenation ("cat") or max pooling ("max") needs no extra sizes.
        self.jump = JumpingKnowledge(mode)
    # In "cat" mode all n_layers + 1 layer outputs are concatenated, so the
    # output layer sees n_hidden * (n_layers + 1) features. Otherwise it sees
    # n_hidden features.
    jump_out_feats = n_hidden * (n_layers + 1) if self.mode == "cat" else n_hidden
    # Output layer for the final prediction.
    self.output_layer = nn.Linear(jump_out_feats, n_out_feats)
    self.criterion = nn.CrossEntropyLoss()
    self.reset_params()  # Initialize the model parameters (defined elsewhere on the class)