in src/sagemaker_sklearn_extension/externals/header.py [0:0]
def __init__(self, column_names: list, target_column_name: str):
    """
    Build the column-name lookup and locate the target column.

    Every distinct name in ``column_names`` is mapped to an ``Indices`` entry
    holding its column position and its feature position, i.e. its position
    once the target column is left out. The target column itself is stored
    with a feature position of ``None``.

    Parameters
    ----------
    column_names : iterable of the column names in the order of occurrence
    target_column_name : str, name of the target column

    Raises
    ------
    ValueError : target_column_name is not present in column_names or duplicate entries found in column_names
    """
    self.target_column_index = None
    self.target_column_name = target_column_name

    # column name -> Indices, kept in order of first occurrence
    self._column_name_indices = OrderedDict()

    # column name -> positions of every occurrence after the first one
    repeated_positions = defaultdict(list)

    # Becomes 1 once the target column has been passed, so that the feature
    # position of every later column is shifted down by one.
    shift = 0

    for position, name in enumerate(column_names):
        first_occurrence = name not in self._column_name_indices
        if first_occurrence:
            self._column_name_indices[name] = Indices(column_index=position, feature_index=position - shift)
        else:
            # repeated name: keep the first entry and remember this position for the error report
            repeated_positions[name].append(position)

        if name == target_column_name:
            # The target is not a feature: replace its entry with one that has
            # no feature index and remember where it sits.
            self._column_name_indices[name] = Indices(column_index=position, feature_index=None)
            self.target_column_index = position
            shift = 1

    # A missing target is reported before duplicates when both problems are present.
    if self.target_column_index is None:
        missing_target_message = (
            "Specified target column '{target_column_name}' is "
            "not a valid column name."
        ).format(target_column_name=target_column_name)
        raise ValueError(missing_target_message)

    if repeated_positions:
        # One line per repeated name; the index shown is the list of repeat positions.
        duplicate_report = "\n".join(
            "{name} at index {index}".format(name=name, index=positions)
            for name, positions in repeated_positions.items()
        )
        raise ValueError("Duplicate column names were found:\n{}".format(duplicate_report))