def query_target_subgraph()

in src/lambda.d/inference/func/inferenceApi.py [0:0]


    def query_target_subgraph(self, target_id, tr_dict, transaction_value_cols, union_id_cols, dummied_col):
        """Extract the 2nd-degree subgraph of the target transaction from the graph DB.

        Usually called after the new test sample's vertex and edges have been
        inserted into the graph DB.

        Args:
            target_id: Vertex id of the target transaction. Everything after the
                first '-' is treated as the transaction's value and must be
                parseable as an int when the target has outgoing neighbours.
            tr_dict: Sequence whose first element is a mapping holding the
                target's JSON-encoded properties under ``attr_version_key``.
            transaction_value_cols: Keys read from each transaction's decoded
                JSON properties, in the order they should appear in the vector.
            union_id_cols: Vertex labels used to reach related transactions.
                NOTE(review): the first entry is not used directly; the first
                union branch is hard-coded to the 'card1' label, as in the
                original query -- confirm that union_id_cols[0] is 'card1'.
            dummied_col: Unused by this method; kept for interface compatibility.

        Returns:
            A ``(subgraph_dict, transaction_embed_value_dict)`` tuple.

            subgraph_dict maps ``'target<>' + feature_name`` to a pair
            ``(transaction_ids, feature_values)`` where ``transaction_ids`` is
            the de-duplicated list of int ids (target included) connected to
            that feature vertex and ``feature_values`` repeats the feature's
            value once per id.

            transaction_embed_value_dict maps each feature name to
            ``{feature_value: [float, ...]}`` (the 'val1'..'val390' entries of
            the feature vertex), plus the key ``'target'`` mapping each related
            transaction value to its ``transaction_value_cols`` vector.

        Example (illustrative call shape only):
            subgraph, n_feats = api.query_target_subgraph(
                't-3661635', tr_dict, transaction_value_cols, union_id_cols, dummied_col)
        """
        subgraph_dict = {}
        neighbor_dict = {}
        transaction_embed_value_dict = {}

        s_t = dt.now()

        conn = self.gremlin_utils.remote_connection()
        try:
            g = self.gremlin_utils.traversal_source(connection=conn)

            target_value = target_id[(target_id.find('-')+1):]

            # Phase 1: for every feature vertex the target points to, collect the
            # transactions that share that feature vertex.
            feature_list = g.V().has(id, target_id).out().id().toList()
            for feat in feature_list:
                sep = feat.find('-')
                feat_name = feat[:sep]
                feat_value = feat[(sep+1):]
                node_list = g.V().has(id, feat).both().limit(MAX_FEATURE_NODE).id().toList()
                connected_ids = [int(target_value)]
                connected_ids += [int(conn_node[(conn_node.find('-')+1):]) for conn_node in node_list]
                connected_ids = list(set(connected_ids))
                subgraph_dict['target<>'+feat_name] = (connected_ids, [feat_value]*len(connected_ids))

            e_t = dt.now()
            logger.info(f'INSIDE query_target_subgraph: subgraph_dict used {(e_t - s_t).total_seconds()} seconds')
            new_s_t = e_t

            # Phase 2: fetch the element maps of all 2-hop transactions in a single
            # round trip. union() is variadic, so every label after the first is
            # passed through regardless of how many columns were supplied.
            union_li = [__.V().has(id, target_id).both().hasLabel(label).both().limit(MAX_FEATURE_NODE) for label in union_id_cols]
            node_dict = g.V().has(id, target_id).union(
                __.both().hasLabel('card1').both().limit(MAX_FEATURE_NODE),
                *union_li[1:]
            ).elementMap().toList()

            e_t = dt.now()
            logger.info(f'INSIDE query_target_subgraph: node_dict used {(e_t - new_s_t).total_seconds()} seconds.')
            new_s_t = e_t

            logger.debug(f'Found {len(node_dict)} nodes from graph dbs...')

            # De-duplicate on the value stored under each element map's first key
            # (the vertex id), keeping the first occurrence. Keying on the id
            # itself avoids merging distinct ids whose hashes happen to collide.
            unique_nodes = {}
            for element in node_dict:
                unique_nodes.setdefault(element.get(next(iter(element))), element)
            logger.debug(f'Found {len(unique_nodes)} nodes without duplication')

            for node, item in unique_nodes.items():
                node_value = node[(node.find('-')+1):]
                try:
                    logger.debug(f'the props of node {node} is {item.get(attr_version_key)}')
                    jsonVal = json.loads(item.get(attr_version_key))
                    neighbor_dict[node_value] = [jsonVal[key] for key in transaction_value_cols]
                    logger.debug(f'neighbor pair is {node_value}, {neighbor_dict[node_value]}')
                except json.JSONDecodeError:
                    # Best effort: skip the malformed vertex and tell the operator
                    # how to remove it.
                    logger.warning(f'Malform node value {node} is {item.get(attr_version_key)}, run below cmd to remove it')
                    logger.info(f'g.V(\'{node}\').drop()')

            # The target's own values come from the caller-supplied record, not
            # from the graph query.
            jsonVal = json.loads(tr_dict[0].get(attr_version_key))
            neighbor_dict[target_value] = [jsonVal[key] for key in transaction_value_cols]

            # Take a fresh timestamp so this phase reports its real duration.
            e_t = dt.now()
            logger.info(f'INSIDE query_target_subgraph: neighbor_dict used {(e_t - new_s_t).total_seconds()} seconds.')
            new_s_t = e_t

            # Phase 3: load the embedded value vector of every 1st-degree feature vertex.
            attr_cols = ['val'+str(x) for x in range(1,391)]
            for attr in feature_list:
                sep = attr.find('-')
                attr_name = attr[:sep]
                attr_value = attr[(sep+1):]
                attr_dict = g.V().has(id, attr).valueMap().toList()[0]
                logger.debug(f'attr is {attr}, dict is {attr_dict}')
                # The property is indexed with [0] because valueMap() returns it as a list.
                jsonVal = json.loads(attr_dict.get(attr_version_key)[0])
                transaction_embed_value_dict[attr_name] = {
                    attr_value: [float(jsonVal[key]) for key in attr_cols]
                }

            e_t = dt.now()
            logger.info(f'INSIDE query_target_subgraph: transaction_embed_value_dict used {(e_t - new_s_t).total_seconds()} seconds. Total test cost {(e_t - s_t).total_seconds()} seconds.')

            transaction_embed_value_dict['target'] = neighbor_dict
        finally:
            # Always release the remote connection, even when a query or a
            # JSON decode above raises.
            conn.close()

        return subgraph_dict, transaction_embed_value_dict