def process_graph_feats()

in anticipation/anticipation/datasets/epic_future_labels.py [0:0]


    def process_graph_feats(self, graph, trunk_features, future_labels):
        """Convert a networkx visit graph into a DGL graph with node data.

        ``graph`` is deep-copied first, so the caller's graph is never
        mutated. Nodes may be pruned (short single-visit nodes when
        ``self.dset == 'gtea'``, plus ``self.graph_augmentation`` outside
        test mode). The surviving nodes, in sorted order, become DGL nodes
        0..N-1.

        Args:
            graph: networkx graph whose nodes carry a ``'members'`` list of
                visit dicts with ``'start'`` / ``'stop'`` entries.
            trunk_features: forwarded unchanged to ``self.get_node_feats``.
            future_labels: dict with ``'ints'``, ``'nouns'`` and ``'verbs'``
                entries; the one selected by ``self.label`` is attached.
                Presumably each is indexable by a boolean mask in
                sorted-node order -- TODO confirm against the caller.

        Returns:
            A ``dgl.DGLGraph`` with every edge added in both directions plus
            a self loop on each node. Node data:
              * ``'feats'``, ``'length'``: rows of ``get_node_feats`` output
                for the kept nodes.
              * ``'labels'``: only set when ``self.label`` is ``'int'``,
                ``'noun'`` or ``'verb'``.
              * ``'cur_status'``: 1 for the node returned by
                ``epic_utils.find_last_visit_node``, 2 for nodes within two
                hops of it, 0 otherwise.
        """
        graph = copy.deepcopy(graph)

        # Computed on the full (unpruned) graph; rows are assumed to follow
        # sorted-node order so the keep mask below lines up with them.
        node_feats, node_length = self.get_node_feats(graph, trunk_features)

        # Drop useless visits -- very important for GTEA's black frames:
        # a node visited once, for a span shorter than self.fps, is pruned.
        keep = torch.ones(len(graph.nodes()))
        if self.dset == 'gtea':
            for i, node in enumerate(sorted(graph.nodes())):
                # `graph.nodes[...]` (not the `graph.node` alias removed in
                # NetworkX 2.4) for node attribute access.
                visits = graph.nodes[node]['members']
                if len(visits) == 1:
                    span = visits[0]['stop'][1] - visits[0]['start'][1]
                    if span < self.fps:
                        keep[i] = 0

        if not self.test_mode and self.graph_augmentation:
            keep = self.graph_augmentation(graph, keep)

        # One mask drives both node removal and feature/label row selection,
        # so the DGL node count always matches the number of feature rows.
        mask = keep == 1
        nodes = sorted(graph.nodes())
        graph.remove_nodes_from(
            [node for node, kept in zip(nodes, mask.tolist()) if not kept])

        # -------------------------------------------------------------------#
        # Build the DGL graph on the pruned node set.
        nodes = sorted(graph.nodes())
        node_to_idx = {node: idx for idx, node in enumerate(nodes)}
        edges = list(graph.edges())
        src = [node_to_idx[u] for u, _ in edges]
        dst = [node_to_idx[v] for _, v in edges]

        g = dgl.DGLGraph()
        g.add_nodes(len(nodes))
        g.add_edges(src, dst)
        g.add_edges(dst, src)  # reverse direction -> undirected message passing
        g.add_edges(g.nodes(), g.nodes())  # self loops

        g.ndata['feats'] = node_feats[mask]
        g.ndata['length'] = node_length[mask]

        label_key = {'int': 'ints', 'noun': 'nouns', 'verb': 'verbs'}.get(self.label)
        if label_key is not None:
            g.ndata['labels'] = future_labels[label_key][mask]

        # Mark the current node (1) and its <=2-hop neighbourhood (2); the
        # center is excluded from the ego graph so its 1 is not overwritten.
        cur_node = epic_utils.find_last_visit_node(graph)
        cur_status = torch.zeros(len(nodes))
        cur_status[node_to_idx[cur_node]] = 1
        for nbh in nx.ego_graph(graph, cur_node, radius=2, center=False).nodes():
            cur_status[node_to_idx[nbh]] = 2
        g.ndata['cur_status'] = cur_status

        return g