in anticipation/anticipation/datasets/epic_future_labels.py [0:0]
def process_graph_feats(self, graph, trunk_features, future_labels):
    """Build a DGL graph with node features/labels from a networkx visit graph.

    The input graph is deep-copied, so the caller's graph is not modified.
    Nodes may be dropped (short single-visit nodes on GTEA, and
    ``self.graph_augmentation`` when training with ``self.graph_aug``); the
    surviving nodes are re-indexed 0..N-1 in sorted node order.

    Args:
        graph: networkx graph whose nodes carry a ``'members'`` list of visit
            dicts with ``'start'`` / ``'stop'`` entries; element ``[1]`` of
            each is compared against ``self.fps`` (presumably a frame index —
            confirm).
        trunk_features: passed through unchanged to ``self.get_node_feats``.
        future_labels: dict with ``'ints'``, ``'nouns'`` and ``'verbs'``
            entries. Each is boolean-masked per node, so it is assumed to be
            a tensor aligned with ``sorted(graph.nodes())`` — TODO confirm.

    Returns:
        A ``dgl.DGLGraph`` with each edge added in both directions plus self
        loops, and node data ``'feats'``, ``'length'``, ``'cur_status'`` and,
        when ``self.label`` is ``'int'``, ``'noun'`` or ``'verb'``,
        ``'labels'``. ``'cur_status'`` is 1 for the node returned by
        ``epic_utils.find_last_visit_node``, 2 for other nodes within two hops
        of it, and 0 otherwise.
    """
    graph = copy.deepcopy(graph)

    # keep[i] refers to the i-th node in sorted order. Features are computed on
    # the unfiltered graph and masked later, so get_node_feats is assumed to
    # return rows in that same sorted order — confirm.
    keep = torch.ones(len(graph.nodes()))
    node_feats, node_length = self.get_node_feats(graph, trunk_features)

    # Drop useless visits (VERY IMPORTANT FOR GTEA's BLACK FRAMES !!!!!):
    # nodes made of a single visit shorter than self.fps.
    if self.dset == 'gtea':
        for i, node in enumerate(sorted(graph.nodes())):
            # ``Graph.node`` was removed in networkx 2.4; ``Graph.nodes[n]``
            # gives the same attribute dict (networkx >= 2.0).
            visits = graph.nodes[node]['members']
            duration = visits[0]['stop'][1] - visits[0]['start'][1]
            if len(visits) == 1 and duration < self.fps:
                keep[i] = 0

    if not self.test_mode and self.graph_aug:
        keep = self.graph_augmentation(graph, keep)

    # One mask drives both node removal and feature/label selection, so the
    # DGL node count always equals the number of masked feature rows.
    mask = keep == 1
    dropped = [node for node, kept in zip(sorted(graph.nodes()), mask.tolist())
               if not kept]
    graph.remove_nodes_from(dropped)

    # -------------------------------------------------------------------#
    # Make the dgl graph now
    nodes = sorted(graph.nodes())
    node_to_idx = {node: idx for idx, node in enumerate(nodes)}
    edges = [(node_to_idx[u], node_to_idx[v]) for u, v in graph.edges()]
    src = [u for u, _ in edges]
    dst = [v for _, v in edges]

    g = dgl.DGLGraph()
    g.add_nodes(len(nodes))
    g.add_edges(src, dst)
    g.add_edges(dst, src)  # undirected
    g.add_edges(g.nodes(), g.nodes())  # add self loops
    g.ndata['feats'] = node_feats[mask]
    g.ndata['length'] = node_length[mask]

    label_key = {'int': 'ints', 'noun': 'nouns', 'verb': 'verbs'}.get(self.label)
    if label_key is not None:
        g.ndata['labels'] = future_labels[label_key][mask]

    # Mark the current (last-visited) node and its 2-hop neighbourhood.
    cur_node = epic_utils.find_last_visit_node(graph)
    cur_status = torch.zeros(len(nodes))
    cur_status[node_to_idx[cur_node]] = 1
    for nbh in nx.ego_graph(graph, cur_node, radius=2, center=False).nodes():
        cur_status[node_to_idx[nbh]] = 2
    g.ndata['cur_status'] = cur_status
    return g