tools/ml-auto-eda/ml_eda/orchestration/analysis_tracker.py (86 lines of code) (raw):

# Copyright 2019 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tracker for storing the performed analysis results""" from __future__ import absolute_import from __future__ import print_function from typing import Set, List, Text, Dict from collections import defaultdict from google.protobuf.json_format import MessageToDict from ml_eda.proto import analysis_entity_pb2 from ml_eda.job_config_util.job_config import JobConfig NAME_SEP = '-' Analysis = analysis_entity_pb2.Analysis Attribute = analysis_entity_pb2.Attribute def _analysis_unique_name(analysis: Analysis) -> Text: """The unique name of an analysis is: [analysis_name]_[attribute_one]_[attribute_two]_ ...... Args: analysis: An instance of analysis_entity_pb2.Analysis Returns: Unique name of an Analysis in string """ attributes = analysis.features # analysis_entity_pb2.Analysis.Name is an integer, need the following # conversion to get its string value analysis_name = Analysis.Name.Name(analysis.name) attribute_names = [att.name for att in attributes] return NAME_SEP.join([analysis_name] + sorted(attribute_names)) class AttributeAnalysisTracker: """ Tracker for storing the performed analysis results related to one particular attribute """ def __init__(self, att_name: Text, att_type: Text): """ Args: att_name: (string), attribute name att_type: (string), attribute type """ self.att_type = att_type self.att_name = att_name self.tracker = dict() def add_analysis(self, analysis: Analysis): """Add an analysis result to attribute tracker Args: analysis: (analysis_entity_pb2.Analysis), performed analysis result """ analysis_unique_name = _analysis_unique_name(analysis) self.tracker[analysis_unique_name] = analysis def get_analysis(self, analysis_name: Text) -> List[Analysis]: """Return all the analysis related to the attribute given an analysis name. Since one attribute can run the same analysis with multiple attributes, this function would return an iterator. Args: analysis_name: (string), name of the analysis specified in the proto Returns: List[analysis_entity_pb2.Analysis] """ analysis = [] for item in self.tracker: if item.startswith(analysis_name): analysis.append(self.tracker[item]) return analysis def get_all_analysis(self) -> List[Analysis]: """Return all the analysis stored in the attribute tracker""" return list(self.tracker.values()) class AnalysisTracker: """ Tracker for storing the performed analysis results """ def __init__(self, job_config: JobConfig): self._job_config = job_config # tracker for tracking all the analysis self._analysis_tracker = dict() # tracker for tracking all attributes having analysis performed self._attribute_tracker = dict() # tracker for attributes in different type self._attribute_type_tracker = defaultdict(set) def add_analysis(self, analysis: Analysis): """Add analysis to two trackers Args: analysis: (analysis_entity_pb2.Analysis) """ analysis_attributes = analysis.features # Get the unique name for the analysis analysis_unique_name = _analysis_unique_name(analysis) # Add analysis to analysis_tracker self._analysis_tracker[analysis_unique_name] = analysis # Add analysis to attribute_tracker for attr in analysis_attributes: if attr.name not in self._attribute_tracker: self._attribute_tracker[attr.name] = AttributeAnalysisTracker( att_name=attr.name, att_type=attr.type) self._attribute_type_tracker[attr.type].add(attr.name) self._attribute_tracker[attr.name].add_analysis(analysis) def get_job_config(self) -> JobConfig: """Get the job config""" return self._job_config def get_target_attribute(self) -> Attribute: """Get the target attribute""" return self._job_config.target_column def get_attribute_names(self) -> Set[Text]: """Get the names of all the involved attributes""" return set(self._attribute_tracker.keys()) def get_num_attribute_names(self) -> Set[Text]: """Get the names of all numerical attributes""" return self._attribute_type_tracker[Attribute.NUMERICAL] def get_cat_attribute_names(self) -> Set[Text]: """Get the names of all categorical attributes""" return self._attribute_type_tracker[Attribute.CATEGORICAL] def get_all_analysis(self) -> List[Analysis]: """Get all the stored analyses results""" return list(self._analysis_tracker.values()) def get_all_analysis_unique_names(self) -> List[Text]: """Get the unique name of all stored analyses""" return list(self._analysis_tracker.keys()) def get_analysis_by_attribute(self, attribute_name: Text) -> List[Analysis]: """Get all the analyses given attribute name""" if attribute_name in self._attribute_tracker: return self._attribute_tracker[attribute_name].get_all_analysis() return [] def get_analysis_by_name(self, analysis_name: Text) -> List[Analysis]: """Get all the analyses given analysis name""" analysis = [] for item in self._analysis_tracker: if item.startswith(analysis_name): analysis.append(self._analysis_tracker[item]) return analysis def get_analysis_by_attribute_and_name(self, attribute_name: Text, analysis_name: Text ) -> List[Analysis]: """Get all the analyses given attribute name and analysis name""" if attribute_name in self._attribute_tracker: return self._attribute_tracker[attribute_name].get_analysis( analysis_name) return [] def export_to_dict(self) -> Dict[Text, Dict]: """Export all analysis in a dictionary, where the Analysis object is serialized.""" export_dict = {} for analysis_name in self._analysis_tracker: analysis_str = MessageToDict(self._analysis_tracker[analysis_name]) export_dict[analysis_name] = analysis_str return export_dict