def _validate_column_mapping()

in src/setfit/trainer.py [0:0]


    def _validate_column_mapping(self, dataset: "Dataset") -> None:
        """
        Validates the provided column mapping against the dataset.
        """
        column_names = set(dataset.column_names)
        if self.column_mapping is None and not self._REQUIRED_COLUMNS.issubset(column_names):
            # Issue #226: load_dataset will automatically assign points to "train" if no split is specified
            if column_names == {"train"} and isinstance(dataset, DatasetDict):
                raise ValueError(
                    "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. "
                    "Did you mean to select the training split with dataset['train']?"
                )
            elif isinstance(dataset, DatasetDict):
                raise ValueError(
                    f"SetFit expected a Dataset, but it got a DatasetDict with the splits {sorted(column_names)}. "
                    "Did you mean to select one of these splits from the dataset?"
                )
            else:
                raise ValueError(
                    f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, "
                    f"but only the columns {sorted(column_names)} were found. "
                    "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
                )
        if self.column_mapping is not None:
            missing_columns = set(self._REQUIRED_COLUMNS)
            # Remove columns that will be provided via the column mapping
            missing_columns -= set(self.column_mapping.values())
            # Remove columns that will be provided because they are in the dataset & not mapped away
            missing_columns -= set(dataset.column_names) - set(self.column_mapping.keys())
            if missing_columns:
                raise ValueError(
                    f"The following columns are missing from the column mapping: {missing_columns}. "
                    "Please provide a mapping for all required columns."
                )
            if not set(self.column_mapping.keys()).issubset(column_names):
                raise ValueError(
                    f"The column mapping expected the columns {sorted(self.column_mapping.keys())} in the dataset, "
                    f"but the dataset had the columns {sorted(column_names)}."
                )