in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/data.py [0:0]
def _get_mask(id_to_node, node_to_id, num_nodes, masked_nodes, additional_mask_rate):
"""
Returns the train and test mask arrays
:param id_to_node: dictionary mapping node names(id) to dgl node idx
:param node_to_id: dictionary mapping dgl node idx to node names(id)
:param num_nodes: number of user/account nodes in the graph
:param masked_nodes: list of nodes to be masked during training, nodes without labels
:param additional_mask_rate: float for additional masking of nodes with labels during training
:return: (list, list) train and test mask array
"""
train_mask = np.ones(num_nodes)
test_mask = np.zeros(num_nodes)
for node_id in masked_nodes:
train_mask[id_to_node[node_id]] = 0
test_mask[id_to_node[node_id]] = 1
if additional_mask_rate and additional_mask_rate < 1:
unmasked = np.array([idx for idx in range(num_nodes) if node_to_id[idx] not in masked_nodes])
yet_unmasked = np.random.permutation(unmasked)[:int(additional_mask_rate*num_nodes)]
train_mask[yet_unmasked] = 0
return train_mask, test_mask