def check_model_type()

in optimum/graphcore/pipelines/__init__.py [0:0]


def check_model_type(self, supported_models: Union[List[str], dict]):
    """
    Verify that the pipeline's model class is among the supported ones.

    When the model is not supported, an error is logged; nothing is raised and
    nothing is returned.

    Args:
        supported_models (`List[str]` or `dict`):
            Either the names of the supported model classes, or a mapping whose values are
            model classes (or tuples of model classes) from which the names are collected.
    """
    if not isinstance(supported_models, list):
        # Anything that is not a list is treated as a model mapping: flatten its
        # values into a flat list of class names.
        collected_names = []
        for _, entry in supported_models.items():
            # A single configuration may map to one class or to a tuple of classes.
            classes = entry if isinstance(entry, tuple) else (entry,)
            collected_names.extend(cls.__name__ for cls in classes)
        supported_models = collected_names

    # Pick the class whose name is compared against the supported names.
    if isinstance(self.model, poptorch.PoplarExecutor):
        # Executor wrapper: look at the first base class of the wrapped user model.
        wrapped_cls = self.model._user_model.__class__
        model_class_name = wrapped_cls.__bases__[0].__name__
    elif isinstance(self.model, IPUGenerationMixin):
        # Generation mixin instance: use the first base class of the model's class.
        model_class_name = self.model.__class__.__bases__[0].__name__
    else:
        model_class_name = self.model.__class__.__name__

    if model_class_name in supported_models:
        return

    logger.error(
        f"The model '{model_class_name}' is not supported for {self.task}. Supported models are"
        f" {supported_models}."
    )