in models.py [0:0]
def __init__(self, hidden_channels, num_layers, max_z, train_dataset,
             use_feature=False, node_embedding=None, dropout=0.5,
             jk=True, train_eps=False):
    """Create the GIN convolution stack and the two trailing Linear layers.

    Args:
        hidden_channels: Width of every hidden representation.
        num_layers: Total number of GIN convolutions: ``conv1`` plus
            ``num_layers - 1`` more stored in ``convs``.
        max_z: Number of rows in ``z_embedding``, i.e. how many distinct
            integer labels it can look up (presumably the ``z`` labels
            passed to ``forward`` -- confirm).
        train_dataset: Only read when ``use_feature`` is true, and then only
            for its ``num_features`` attribute.
        use_feature: If true, the input width of ``conv1`` grows by
            ``train_dataset.num_features`` (presumably raw node features
            concatenated in ``forward`` -- confirm).
        node_embedding: Optional embedding module stored as
            ``self.node_embedding``. When given, the input width of ``conv1``
            grows by its ``embedding_dim``.
        dropout: Stored as ``self.dropout``. It is not applied here
            (presumably used in ``forward``).
        jk: If true, ``lin1`` takes ``num_layers * hidden_channels`` inputs,
            presumably the concatenated outputs of every GIN layer (jumping
            knowledge). Confirm against ``forward``. Otherwise ``lin1`` takes
            ``hidden_channels`` inputs.
        train_eps: Forwarded unchanged to every ``GINConv``.
    """
    super(GIN, self).__init__()

    def make_gin_layer(in_width):
        # One GIN convolution whose update MLP is
        # Linear -> ReLU -> Linear -> ReLU -> BN, all hidden_channels wide.
        update_mlp = Sequential(
            Linear(in_width, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            BN(hidden_channels),
        )
        return GINConv(update_mlp, train_eps=train_eps)

    # NOTE: submodules are created in a fixed order (node_embedding,
    # z_embedding, conv1, convs, lin1, lin2). Parameter initialisation
    # presumably draws from the global RNG, so reordering these lines would
    # change the seeded initial weights and the state_dict key order.
    self.use_feature = use_feature
    self.node_embedding = node_embedding
    self.max_z = max_z
    self.z_embedding = Embedding(max_z, hidden_channels)
    self.jk = jk

    # conv1 sees the z embedding plus any optional extra per-node inputs.
    feature_width = train_dataset.num_features if use_feature else 0
    embedding_width = (
        node_embedding.embedding_dim if node_embedding is not None else 0
    )
    first_width = hidden_channels + feature_width + embedding_width

    self.conv1 = make_gin_layer(first_width)
    self.convs = torch.nn.ModuleList(
        [make_gin_layer(hidden_channels) for _ in range(num_layers - 1)]
    )
    self.dropout = dropout

    head_width = num_layers * hidden_channels if jk else hidden_channels
    self.lin1 = Linear(head_width, hidden_channels)
    self.lin2 = Linear(hidden_channels, 1)