def from_pretrained()

in tools/scripts/features/frcnn/modeling_frcnn.py
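
Resolves a Config when one isn't passed in, locates the weights file on disk, at a remote URL, or in the local cache, loads the state dict (converting old-format gamma/beta parameter names to weight/bias), reports missing and unexpected keys, and returns the model in eval mode.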


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
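        # Note: any remaining entries in *model_args / **kwargs are not used
        # anywhere else in this implementation.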

        # Load config if we don't provide a configuration
        if not isinstance(config, Config):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config = Config.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
            )

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                ):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, WEIGHTS_NAME
                    )
                else:
                    raise OSError(
                        f"No file named {WEIGHTS_NAME} found in "
                        f"directory {pretrained_model_name_or_path}"
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path
            ):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert from_tf, (
                    "We found a TensorFlow checkpoint at "
                    f"{pretrained_model_name_or_path}.index, please set "
                    "from_tf=True to load from this checkpoint"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                # No usable checkpoint was found; fail here rather than
                # leaving archive_file undefined below.
                raise OSError(
                    "Can't find a checkpoint file for "
                    f"'{pretrained_model_name_or_path}'"
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise OSError
            except OSError:
                msg = f"Can't load weights for '{pretrained_model_name_or_path}'."
                raise OSError(msg)

            if resolved_archive_file == archive_file:
                print(f"loading weights file {archive_file}")
            else:
                print(
                    f"loading weights file {archive_file} from cache at "
                    f"{resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config)

        if state_dict is None:
            try:
                try:
                    # Standard PyTorch checkpoint
                    state_dict = torch.load(resolved_archive_file, map_location="cpu")
                except Exception:
                    # Fall back to the module's own loader for non-torch
                    # checkpoint formats
                    state_dict = load_checkpoint(resolved_archive_file)
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 "
                    "checkpoint, please set from_tf=True."
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata
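        # .copy() does not carry attributes over, so the _metadata attribute
        # (used by _load_from_state_dict for versioned loading) has to be
        # re-attached by hand.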

        model_to_load = model
        # Non-strict load so we can report missing/unexpected keys below
        # instead of raising immediately on any mismatch.
        load_result = model_to_load.load_state_dict(state_dict, strict=False)
        missing_keys.extend(load_result.missing_keys)
        unexpected_keys.extend(load_result.unexpected_keys)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1]
                for key in model.state_dict().keys()
            ]
            # Use a set difference; subtracting a dict_keys view from a list
            # would raise a TypeError.
            missing_keys.extend(
                set(head_model_state_dict_without_base_prefix)
                - set(base_model_state_dict)
            )

        if len(unexpected_keys) > 0:
            print(
                "Some weights of the model checkpoint at "
                f"{pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing "
                f"{model.__class__.__name__} from the checkpoint of a model trained "
                "on another task or with another architecture (e.g. initializing a "
                "BertForSequenceClassification model from a BertForPreTraining model)."
                "\n- This IS NOT expected if you are initializing "
                f"{model.__class__.__name__} from the checkpoint of a model "
                "that you expect to be exactly identical (initializing a "
                "BertForSequenceClassification model from a "
                "BertForSequenceClassification model)."
            )
        else:
            print(
                "All model checkpoint weights were used when initializing "
                f"{model.__class__.__name__}.\n"
            )
        if len(missing_keys) > 0:
            print(
                f"Some weights of {model.__class__.__name__} were not initialized "
                f"from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                "You should probably TRAIN this model on a down-stream task to be "
                "able to use it for predictions and inference."
            )
        else:
            print(
                f"All the weights of {model.__class__.__name__} were initialized "
                f"from the model checkpoint at {pretrained_model_name_or_path}.\n"
                "If your task is similar to the task the model of the "
                "checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for "
                "predictions without further training."
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        # Set model in evaluation mode to deactivate dropout modules by default
        model.eval()

        return model
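
For reference, a minimal usage sketch. The class name GeneralizedRCNN, the utils.Config import, and the checkpoint identifier are assumptions about the surrounding repository, not guaranteed by the listing above:

    # Hedged usage sketch: GeneralizedRCNN, utils.Config, and the checkpoint
    # id below are assumptions about the surrounding repo.
    from utils import Config
    from modeling_frcnn import GeneralizedRCNN

    name = "unc-nlp/frcnn-vg-finetuned"  # example/assumed checkpoint id
    config = Config.from_pretrained(name)
    model = GeneralizedRCNN.from_pretrained(name, config=config)
    # The returned model is already in eval mode.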