in models.py [0:0]
def __init__(self, hidden_channels, num_layers, max_z, train_dataset=None,
             use_feature=False, node_embedding=None, dropout=0.5):
    """Build the z-embedding, the stack of SAGEConv layers and the MLP head.

    Args:
        hidden_channels: Width of ``z_embedding``, of every ``SAGEConv``
            output, and of the hidden ``lin1`` layer.
        num_layers: Number of ``SAGEConv`` layers. The first maps the
            combined input width to ``hidden_channels``; the remaining
            ``num_layers - 1`` map ``hidden_channels`` to itself. Values
            below 1 still build exactly one layer.
        max_z: Number of rows in ``z_embedding``; looked-up indices must lie
            in ``[0, max_z)``. NOTE(review): presumably these are integer
            node labels supplied to ``forward`` -- confirm.
        train_dataset: Only consulted when ``use_feature`` is true, for its
            ``num_features`` attribute.
        use_feature: If true, widen the first layer's input by
            ``train_dataset.num_features``.
        node_embedding: Optional module exposing ``embedding_dim``; when
            given, the first layer's input is widened by that many channels.
        dropout: Stored on ``self.dropout``. NOTE(review): presumably used
            as a dropout probability in ``forward`` -- confirm.

    Raises:
        ValueError: If ``use_feature`` is true but ``train_dataset`` is None.
    """
    super().__init__()
    self.use_feature = use_feature
    self.node_embedding = node_embedding
    self.max_z = max_z
    self.z_embedding = Embedding(self.max_z, hidden_channels)
    self.convs = ModuleList()

    # Input width of the first layer: the z-embedding width plus each
    # optional source that is enabled (presumably concatenated in
    # ``forward`` -- TODO confirm).
    initial_channels = hidden_channels
    if self.use_feature:
        if train_dataset is None:
            # Fail early with an actionable message instead of an opaque
            # AttributeError on ``None.num_features``.
            raise ValueError(
                "use_feature=True requires train_dataset to determine the "
                "number of input features")
        initial_channels += train_dataset.num_features
    if self.node_embedding is not None:
        initial_channels += self.node_embedding.embedding_dim

    self.convs.append(SAGEConv(initial_channels, hidden_channels))
    for _ in range(num_layers - 1):
        self.convs.append(SAGEConv(hidden_channels, hidden_channels))

    self.dropout = dropout
    # Two-layer head reducing the hidden representation to a single output.
    self.lin1 = Linear(hidden_channels, hidden_channels)
    self.lin2 = Linear(hidden_channels, 1)