in torch_geometric_utils.py [0:0]
def propagate(self, edge_index, size=None, **kwargs):
    r"""The initial call to start propagating messages.

    Collects the arguments of :meth:`message` from :obj:`kwargs`, calls
    :meth:`message`, aggregates its output with ``scatter_`` using
    ``self.aggr`` and finally passes the result to :meth:`update`.

    Args:
        edge_index (Tensor): The indices of a general (sparse) assignment
            matrix with shape :obj:`[N, M]` (can be directed or
            undirected).
        size (list or tuple, optional): The size :obj:`[N, M]` of the
            assignment matrix. If set to :obj:`None`, the size is inferred
            automatically from the data in :obj:`kwargs`.
            (default: :obj:`None`)
        **kwargs: Any additional data which is needed to construct messages
            and to update node embeddings. A value consumed through an
            ``_i``/``_j`` suffixed message argument may be a single object
            or a two-element tuple/list holding one entry per side; either
            entry may be :obj:`None`.

    Returns:
        The value returned by :meth:`update`.

    Raises:
        ValueError: If the first-dimension size of a provided value
            disagrees with the size given or inferred so far.
        AssertionError: If :obj:`size` or a tuple/list value does not have
            exactly two entries.
    """
    size = [None, None] if size is None else list(size)
    assert len(size) == 2, "`size` must have exactly two entries"

    # `i` is the row of `edge_index` that is aggregated over, `j` the other.
    i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
    ij = {"_i": i, "_j": j}

    message_args = []
    for arg in self.__message_args__:
        suffix = arg[-2:]
        if suffix not in ij:
            # Plain argument: forwarded untouched (None when absent).
            message_args.append(kwargs.get(arg, None))
            continue

        tmp = kwargs.get(arg[:-2], None)
        if tmp is None:  # pragma: no cover
            message_args.append(tmp)
            continue

        idx = ij[suffix]
        if isinstance(tmp, (tuple, list)):
            assert len(tmp) == 2, "paired values must have two entries"
            other = tmp[1 - idx]
            if other is not None:
                # The opposite side still contributes to size inference.
                if size[1 - idx] is None:
                    size[1 - idx] = other.size(0)
                if size[1 - idx] != other.size(0):
                    raise ValueError(__size_error_msg__)
            tmp = tmp[idx]

        if tmp is None:
            # The requested side of a pair is missing (e.g. `(x, None)`):
            # forward None, exactly as for a value that is absent
            # altogether, instead of failing on `None.size(0)`.
            message_args.append(tmp)
            continue

        if size[idx] is None:
            size[idx] = tmp.size(0)
        if size[idx] != tmp.size(0):
            raise ValueError(__size_error_msg__)

        # Gather one row per edge for the requested side.
        message_args.append(torch.index_select(tmp, 0, edge_index[idx]))

    # A side whose size is still unknown falls back to the other side's.
    size[0] = size[1] if size[0] is None else size[0]
    size[1] = size[0] if size[1] is None else size[1]

    kwargs["edge_index"] = edge_index
    kwargs["size"] = size

    # Special arguments are inserted at their recorded positions; they are
    # read from `kwargs`, which now also holds `edge_index` and `size`.
    for (idx, arg) in self.__special_args__:
        if arg[-2:] in ij:
            message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
        else:
            message_args.insert(idx, kwargs[arg])

    update_args = [kwargs[arg] for arg in self.__update_args__]

    out = self.message(*message_args)
    out = scatter_(self.aggr, out, edge_index[i], dim_size=size[i])
    out = self.update(out, *update_args)
    return out