tools/ml-auto-eda/ml_eda/job_config_util/job_config.py (121 lines of code) (raw):
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Definition of utility class for holding the configuration of running
analysis"""
from __future__ import absolute_import
from __future__ import print_function
from typing import List
from ml_eda.constants import c
from ml_eda.proto import analysis_entity_pb2
class JobConfig:
  """Utility class for holding the configuration of running analysis.

  An instance bundles three things:
    * the datasource description (an `analysis_entity_pb2.DataSource`
      message holding type, location, target and feature attributes),
    * the boolean switches saying which analyses should run, and
    * the integer tuning parameters of those analyses.

  Everything is exposed through read-only properties; the only mutator is
  `update_low_card_categorical`.
  """
  # pylint: disable=too-many-instance-attributes

  def __init__(self,
               datasource_type: str,
               datasource_location: str,
               target_column: str,
               numerical_attributes: List[str],
               categorical_attributes: List[str],
               analysis_run_ops,
               analysis_run_config):
    """Build the job configuration.

    Args:
      datasource_type: (string), compared against `c.datasources.BIGQUERY`
        and `c.datasources.CSV` to choose the proto datasource type
      datasource_location: (string), stored as the datasource location
      target_column: (string), name of the target column, or
        `c.schema.NULL` when there is no target
      numerical_attributes: (List[string]), names of numerical attributes
      categorical_attributes: (List[string]), names of categorical attributes
      analysis_run_ops: object exposing `getboolean(key)`, used to read the
        run switches (presumably a configparser section - confirm with the
        caller)
      analysis_run_config: object exposing `getint(key)`, used to read the
        analysis parameters (presumably a configparser section - confirm
        with the caller)

    Raises:
      ValueError: if `target_column` is neither `c.schema.NULL` nor listed
        in `numerical_attributes` or `categorical_attributes`
    """
    # pylint: disable=too-many-arguments

    # Create a fresh message for every instance. Keeping the message as a
    # class attribute would make all JobConfig objects share (and keep
    # appending features to) one and the same DataSource.
    self._datasource = analysis_entity_pb2.DataSource()

    # NOTE: any other datasource_type value is silently accepted and the
    # type field of the message is left at its default.
    if datasource_type == c.datasources.BIGQUERY:
      self._datasource.type = analysis_entity_pb2.DataSource.BIGQUERY
    elif datasource_type == c.datasources.CSV:
      self._datasource.type = analysis_entity_pb2.DataSource.CSV
    self._datasource.location = datasource_location

    # The kind of list the target belongs to decides the ML problem type.
    self._datasource.target.name = target_column
    if target_column == c.schema.NULL:
      self._ml_type = c.ml_type.NULL
    elif target_column in numerical_attributes:
      self._datasource.target.type = analysis_entity_pb2.Attribute.NUMERICAL
      self._ml_type = c.ml_type.REGRESSION
    elif target_column in categorical_attributes:
      self._datasource.target.type = analysis_entity_pb2.Attribute.CATEGORICAL
      self._ml_type = c.ml_type.CLASSIFICATION
    else:
      raise ValueError('The specified target column {} does not belong to '
                       'Categorical or Numerical features in the '
                       'job_config.ini'.format(target_column))

    self._numerical_attributes = self._create_numerical_attributes(
        numerical_attributes)
    self._datasource.features.extend(self._numerical_attributes)

    self._categorical_attributes = self._create_categorical_attributes(
        categorical_attributes)
    self._datasource.features.extend(self._categorical_attributes)

    # This is for tracking categorical attributes with limited cardinality.
    # It starts out as all categorical attributes and can be replaced later
    # through update_low_card_categorical().
    self._categorical_low_card_attributes = self._categorical_attributes

    # Running configuration
    self._contingency_table_run = analysis_run_ops.getboolean(
        c.analysis_run.CONTINGENCY_TABLE_RUN)
    self._table_descriptive_run = analysis_run_ops.getboolean(
        c.analysis_run.TABLE_DESCRIPTIVE_RUN)
    self._pearson_corr_run = analysis_run_ops.getboolean(
        c.analysis_run.PEARSON_CORRELATION_RUN)
    self._information_gain_run = analysis_run_ops.getboolean(
        c.analysis_run.INFORMATION_GAIN_RUN)
    self._chi_square_run = analysis_run_ops.getboolean(
        c.analysis_run.CHI_SQUARE_RUN)
    self._anova_run = analysis_run_ops.getboolean(
        c.analysis_run.ANOVA_RUN)

    # Analysis configuration
    self._histogram_bin = analysis_run_config.getint(
        c.analysis_config.HISTOGRAM_BIN)
    self._value_counts_limit = analysis_run_config.getint(
        c.analysis_config.VALUE_COUNTS_LIMIT)
    self._general_cardinality_limit = analysis_run_config.getint(
        c.analysis_config.GENERAL_CARDINALITY_LIMIT)

  @staticmethod
  def _create_attributes(attribute_names: List[str],
                         attribute_type: int
                         ) -> List[analysis_entity_pb2.Attribute]:
    """Construct analysis_entity_pb2.Attribute instance for attributes

    Args:
      attribute_names: (List[string]), name list of the attribute
      attribute_type: (int), type of the attribute defined in the proto

    Returns:
      List[analysis_entity_pb2.Attribute], one message per name, in the
      same order as `attribute_names`
    """
    return [
        analysis_entity_pb2.Attribute(name=name, type=attribute_type)
        for name in attribute_names
    ]

  def _create_numerical_attributes(self, attribute_names: List[str]
                                   ) -> List[analysis_entity_pb2.Attribute]:
    """Construct analysis_entity_pb2.Attribute instance for numerical attributes

    Args:
      attribute_names: (List[string]), name list of the attributes

    Returns:
      List[analysis_entity_pb2.Attribute]
    """
    return self._create_attributes(attribute_names,
                                   analysis_entity_pb2.Attribute.NUMERICAL)

  def _create_categorical_attributes(self, attribute_names: List[str]
                                     ) -> List[analysis_entity_pb2.Attribute]:
    """Construct analysis_entity_pb2.Attribute instance for cat attributes.

    Args:
      attribute_names: (List[string]), name list of the attributes

    Returns:
      List[analysis_entity_pb2.Attribute]
    """
    return self._create_attributes(attribute_names,
                                   analysis_entity_pb2.Attribute.CATEGORICAL)

  def update_low_card_categorical(self, features):
    """Replace the tracked low cardinality categorical attributes.

    Args:
      features: new value returned by `low_card_categorical_attributes`;
        it is stored as given, without copying or validation
    """
    self._categorical_low_card_attributes = features

  @property
  def datasource(self):
    """The analysis_entity_pb2.DataSource message built for this job."""
    return self._datasource

  @property
  def target_column(self):
    """The target attribute message of the datasource (not just its name)."""
    return self._datasource.target

  @property
  def ml_type(self):
    """ML problem type (one of the c.ml_type values) derived from the target."""
    return self._ml_type

  @property
  def numerical_attributes(self):
    """List of Attribute messages created for the numerical attributes."""
    return self._numerical_attributes

  @property
  def categorical_attributes(self):
    """List of Attribute messages created for the categorical attributes."""
    return self._categorical_attributes

  @property
  def low_card_categorical_attributes(self):
    """Categorical attributes currently tracked as low cardinality."""
    return self._categorical_low_card_attributes

  # Analysis Running Configuration
  @property
  def contingency_table_run(self):
    """Value read for c.analysis_run.CONTINGENCY_TABLE_RUN."""
    return self._contingency_table_run

  @property
  def table_descriptive_run(self):
    """Value read for c.analysis_run.TABLE_DESCRIPTIVE_RUN."""
    return self._table_descriptive_run

  @property
  def pearson_corr_run(self):
    """Value read for c.analysis_run.PEARSON_CORRELATION_RUN."""
    return self._pearson_corr_run

  @property
  def information_gain_run(self):
    """Value read for c.analysis_run.INFORMATION_GAIN_RUN."""
    return self._information_gain_run

  @property
  def chi_square_run(self):
    """Value read for c.analysis_run.CHI_SQUARE_RUN."""
    return self._chi_square_run

  @property
  def anova_run(self):
    """Value read for c.analysis_run.ANOVA_RUN."""
    return self._anova_run

  # Analysis Configuration
  @property
  def histogram_bin(self):
    """Value read for c.analysis_config.HISTOGRAM_BIN."""
    return self._histogram_bin

  @property
  def value_counts_limit(self):
    """Value read for c.analysis_config.VALUE_COUNTS_LIMIT."""
    return self._value_counts_limit

  @property
  def general_cardinality_limit(self):
    """Value read for c.analysis_config.GENERAL_CARDINALITY_LIMIT."""
    return self._general_cardinality_limit