def from_pretrained()

in eland/ml/pytorch/wrappers.py [0:0]


    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
        """Load the model identified by ``model_id`` and wrap it, if supported.

        The configuration is fetched first and checked for compatibility:
        its ``model_type`` must be ``"dpr"`` and it must declare exactly one
        architecture that appears in
        ``_DPREncoderWrapper._SUPPORTED_MODELS_NAMES``.

        Args:
            model_id: Identifier handed to the ``from_pretrained`` loaders.
            token: Optional token forwarded to both the configuration loader
                and the model loader (presumably the access token required
                for private or gated repositories -- confirm with callers).

        Returns:
            A ``_DPREncoderWrapper`` around the loaded model, or ``None``
            when the configuration is not compatible.
        """
        config = AutoConfig.from_pretrained(model_id, token=token)

        def is_compatible() -> bool:
            is_dpr_model = config.model_type == "dpr"
            # Exactly one declared architecture is required so that the class
            # to instantiate below is unambiguous.
            has_architectures = (
                config.architectures is not None and len(config.architectures) == 1
            )
            is_supported_architecture = has_architectures and (
                config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
            )
            return is_dpr_model and is_supported_architecture

        if not is_compatible():
            return None

        # Resolve the concrete model class by name from the transformers
        # module. The token must be forwarded here too: the configuration
        # was fetched with it, and loading the weights without it would fail
        # for any repository that needs authentication.
        model_class = getattr(transformers, config.architectures[0])
        model = model_class.from_pretrained(model_id, token=token, torchscript=True)
        return _DPREncoderWrapper(model)