def get_node_to_idx_mapping()

in src/graph_notebook/notebooks/03-Neptune-ML/02-SPARQL/neptune_ml_sparql_utils.py [0:0]


def get_node_to_idx_mapping(training_job_name: str = None, dataprocessing_job_name: str = None,
                            model_artifacts_location: str = './model-artifacts', vertex_label: str = None):
    """Load the node-to-index mapping produced by a Neptune ML job.

    The mapping is read from a pickle file cached under
    ``<model_artifacts_location>/<job_name>/``. If that file does not exist
    locally, it is first downloaded from the job's S3 output location.

    Args:
        training_job_name: Id of a modeltraining job. Takes precedence over
            ``dataprocessing_job_name`` when both are given. The mapping is
            read from ``mapping.info`` under the key ``"node2id"``.
        dataprocessing_job_name: Id of a dataprocessing job. Used only when
            ``training_job_name`` is None. The mapping is read from
            ``info.pkl`` under the key ``"node_id_map"``.
        model_artifacts_location: Local directory used to cache artifacts.
        vertex_label: If given and present in the mapping, only that label's
            sub-mapping is returned. If given but absent, the valid labels are
            printed and the full mapping is returned.

    Returns:
        The loaded mapping (or the sub-mapping for ``vertex_label``). Returns
        None if the file is not cached locally and no S3 output location could
        be found for the job.

    Raises:
        ValueError: If neither job id is provided.
    """
    import posixpath  # S3 URIs always use '/', independent of the local OS.

    # Explicit raise instead of `assert` so validation survives `python -O`.
    if training_job_name is None and dataprocessing_job_name is None:
        raise ValueError(
            "You must provide either a modeltraining job id or a dataprocessing job id to obtain node to index mappings")

    if training_job_name is not None:
        job_name = training_job_name
        job_type, filename, mapping_key = "modeltraining", "mapping.info", "node2id"
    else:
        job_name = dataprocessing_job_name
        job_type, filename, mapping_key = "dataprocessing", "info.pkl", "node_id_map"

    # get mappings, downloading them into the local cache if necessary
    local_dir = os.path.join(model_artifacts_location, job_name)
    local_path = os.path.join(local_dir, filename)
    if not os.path.exists(local_path):
        job_s3_output = get_neptune_ml_job_output_location(job_name, job_type)
        print(job_s3_output)
        if not job_s3_output:
            return None
        # os.path.join would insert '\\' on Windows, yielding an invalid S3 URI.
        S3Downloader.download(posixpath.join(job_s3_output, filename), local_dir)

    # NOTE(review): pickle.load can execute arbitrary code; only load artifacts
    # produced by trusted Neptune ML jobs.
    with open(local_path, "rb") as f:
        mapping = pickle.load(f)[mapping_key]

    if vertex_label is not None:
        if vertex_label in mapping:
            mapping = mapping[vertex_label]
        else:
            print("Mapping for vertex label: {} not found.".format(vertex_label))
            print("valid vertex labels which have vertices mapped to embeddings: {} ".format(list(mapping.keys())))
            print("Returning mapping for all valid vertex labels")

    return mapping