in src/graph_notebook/notebooks/03-Neptune-ML/03-Sample-Applications/04-Telco-Networks/neptune_ml_utils.py [0:0]
def get_node_to_idx_mapping(training_job_name: str = None, dataprocessing_job_name: str = None,
                            model_artifacts_location: str = './model-artifacts', vertex_label: str = None):
    """Load the node-to-index mapping produced by a Neptune ML job.

    The mapping file is read from ``<model_artifacts_location>/<job_name>/``. If it is not
    cached there yet, it is first downloaded from the job's S3 output location.

    Args:
        training_job_name: id of a modeltraining job. Takes precedence when both ids are
            given. Its mapping is read from ``mapping.info`` under the ``"node2id"`` key.
        dataprocessing_job_name: id of a dataprocessing job. Its mapping is read from
            ``info.pkl`` under the ``"node_id_map"`` key.
        model_artifacts_location: local directory under which per-job artifacts are cached.
        vertex_label: if given and present in the mapping, only that label's entry is returned.

    Returns:
        ``mapping[vertex_label]`` when ``vertex_label`` is found. Otherwise, the full mapping
        keyed by vertex label (presumably each value maps node ids to embedding indices;
        confirm against the Neptune ML artifact format). ``None`` if the file is not cached
        locally and the job's S3 output location could not be determined.

    Raises:
        AssertionError: if neither job id is provided.
    """
    assert training_job_name is not None or dataprocessing_job_name is not None, \
        "You must provide either a modeltraining job id or a dataprocessing job id to obtain node to index mappings"

    # The job kind decides which artifact file to read and which pickle key holds the mapping.
    if training_job_name is not None:
        job_name, job_type = training_job_name, "modeltraining"
        filename, mapping_key = "mapping.info", "node2id"
    else:
        job_name, job_type = dataprocessing_job_name, "dataprocessing"
        filename, mapping_key = "info.pkl", "node_id_map"

    # get mappings (download once, then reuse the local copy)
    job_dir = os.path.join(model_artifacts_location, job_name)
    local_path = os.path.join(job_dir, filename)
    if not os.path.exists(local_path):
        job_s3_output = get_neptune_ml_job_output_location(job_name, job_type)
        print(job_s3_output)
        if not job_s3_output:
            return
        # S3 keys always use "/" separators; os.path.join would insert "\\" on Windows.
        s3_uri = "{}/{}".format(job_s3_output.rstrip("/"), filename)
        S3Downloader.download(s3_uri, job_dir)

    # NOTE(review): pickle.load can execute arbitrary code embedded in the file. Only use
    # this with artifacts produced by trusted Neptune ML jobs.
    with open(local_path, "rb") as f:
        mapping = pickle.load(f)[mapping_key]

    if vertex_label is not None:
        if vertex_label in mapping:
            return mapping[vertex_label]
        print("Mapping for vertex label: {} not found.".format(vertex_label))
        print("valid vertex labels which have vertices mapped to embeddings: {} ".format(list(mapping.keys())))
        print("Returning mapping for all valid vertex labels")
    return mapping