def get_pretrained()

in src/mlm/models/__init__.py [0:0]
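
For context, the body below depends on this module's top-level imports and a few package-internal symbols. The following is a minimal sketch of those imports, assuming the standard GluonNLP and transformers APIs; the package-internal locations of SUPPORTED, the *ForMaskedLMOptimized classes, and BERTRegression are not shown in this excerpt, so they appear only as commented assumptions:

import logging
from pathlib import Path
from typing import List, Optional, Tuple

import mxnet as mx
import gluonnlp as nlp
from gluonnlp.model import get_model
from mxnet.gluon import Block
import transformers

# Package-internal symbols used below (module paths are illustrative assumptions):
# from .bert import AlbertForMaskedLMOptimized, BertForMaskedLMOptimized, \
#     DistilBertForMaskedLMOptimized
# from .gluon import BERTRegression
# SUPPORTED: the collection of GluonNLP model names handled by the MXNet branch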


def get_pretrained(ctxs: List[mx.Context],
                   name: str = 'bert-base-en-uncased',
                   params_file: Optional[Path] = None,
                   cased: bool = False,
                   finetune: bool = False,
                   regression: bool = False,
                   freeze: int = 0,
                   root: Optional[Path] = None) -> Tuple[Block, nlp.Vocab, nlp.data.BERTTokenizer]:

    if name not in SUPPORTED:
        logging.warning("Model '{}' not recognized as an MXNet model; treating as PyTorch model".format(name))
        model_fullname = name
        model_name = model_fullname.split('/')[-1]

        if model_name.startswith('albert-'):

            if params_file is None:
                model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
            else:
                model, loading_info = AlbertForMaskedLMOptimized.from_pretrained(params_file, output_loading_info=True)

            tokenizer = transformers.AlbertTokenizer.from_pretrained(model_fullname)
            vocab = None

        elif model_name.startswith('bert-'):

            if params_file is None:
                model, loading_info = BertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
            else:
                model, loading_info = BertForMaskedLMOptimized.from_pretrained(params_file, output_loading_info=True)

            tokenizer = transformers.BertTokenizer.from_pretrained(model_fullname)
            vocab = None

        elif model_name.startswith('distilbert-'):

            if params_file is None:
                model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(model_fullname, output_loading_info=True)
            else:
                model, loading_info = DistilBertForMaskedLMOptimized.from_pretrained(params_file, output_loading_info=True)

            tokenizer = transformers.DistilBertTokenizer.from_pretrained(model_fullname)
            vocab = None

        elif model_name.startswith('xlm-'):

            model, loading_info = transformers.XLMWithLMHeadModel.from_pretrained(model_fullname, output_loading_info=True)
            tokenizer = transformers.XLMTokenizer.from_pretrained(model_fullname)
            vocab = None

            # TODO: Not needed in transformers v3? Will vet.
            #
            # # TODO: The loading code in `transformers` assumes pred_layer is under transformers, so the LM head is not loaded properly. We load manually:
            # archive_file = transformers.XLMWithLMHeadModel.pretrained_model_archive_map[model_fullname]
            # resolved_archive_file = transformers.file_utils.cached_path(archive_file)
            # pretrained_state_dict = torch.load(resolved_archive_file, map_location='cpu')
            # new_state_dict = model.state_dict()
            # new_state_dict.update(
            #     {
            #         'pred_layer.proj.weight': pretrained_state_dict['pred_layer.proj.weight'],
            #         'pred_layer.proj.bias': pretrained_state_dict['pred_layer.proj.bias']
            #     }
            # )
            # model.load_state_dict(new_state_dict)

        else:
            raise ValueError("Model '{}' is not currently a supported PyTorch model".format(name))

    # Name format: model-size-lang-cased/uncased(-dataset / special characteristic)
    # e.g., 'bert-base-en-uncased-owt', 'gpt2-117m-en-cased'
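    # For example, 'bert-large-en-uncased-owt' parses below into model_name='bert',
    # size='large', lang='en', cased=False, dataset='owt'.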
    else:
        name_parts = name.split('-')
        model_name = name_parts[0]
        size = name_parts[1]
        lang = name_parts[2]
        if name_parts[3] == 'cased':
            cased = True
        elif name_parts[3] == 'uncased':
            cased = False
        dataset = name_parts[4] if len(name_parts) == 5 else None

        if freeze < 0:
            raise ValueError("# of initial layers to freeze must be non-negative")

        if params_file is not None and dataset is not None:
            logging.warning("Model parameters '{}' was provided, ignoring dataset suffix '{}'".format(params_file, dataset))

        if model_name == 'bert' and size != 'base_bertpr':

            if cased:
                dataset_suffix = '_cased'
            else:
                dataset_suffix = '_uncased'

            if size == 'base':
                model_fullname = 'bert_12_768_12'
            elif size == 'large':
                model_fullname = 'bert_24_1024_16'

            if lang == 'en':
                if dataset is None:
                    dataset_prefix = 'book_corpus_wiki_en'
                elif dataset == 'owt':
                    dataset_prefix = 'openwebtext_book_corpus_wiki_en'
            elif lang == 'multi':
                dataset_prefix = 'wiki_multilingual'

            # Get stock BERT with MLM outputs
            kwargs = {
                'dataset_name': dataset_prefix + dataset_suffix,
                'pretrained': True,
                'ctx': ctxs,
                'use_pooler': False,
                'use_decoder': False,
                'use_classifier': False
            }
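            # The pooled [CLS] output feeds the fine-tuning / regression heads;
            # plain MLM scoring instead needs the decoder (masked-LM head).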
            if finetune or regression:
                kwargs['use_pooler'] = True
            else:
                kwargs['use_decoder'] = True
            # Override GluonNLP's default location?
            if root is not None:
                kwargs['root'] = str(root)
            model, vocab = get_model(model_fullname, **kwargs)

            # Freeze initial layers if needed
            for i in range(freeze):
                model.encoder.transformer_cells[i].collect_params().setattr('grad_req', 'null')

            # Wrapper if appropriate
            if regression:
                # Wrap the base model with a regression head and initialize it
                model = BERTRegression(model, dropout=0.1)
                model.regression.initialize(init=mx.init.Normal(1.0), ctx=ctxs)

            # MXNet warning message suggests this when softmaxing in float16
            # But float16 is buggy, so let's halve our inference speed for now :(
            # os.environ['MXNET_SAFE_ACCUMULATION'] = '1'
            # model.cast('float16')

            # Get tokenizer
            tokenizer = nlp.data.BERTTokenizer(vocab, lower=(not cased))

        elif model_name == 'roberta':

            if cased:
                dataset_suffix = '_cased'
            else:
                raise ValueError('Uncased not supported')

            if size == 'base':
                model_fullname = 'roberta_12_768_12'
            elif size == 'large':
                model_fullname = 'roberta_24_1024_16'

            if lang == 'en' and dataset is None:
                dataset_prefix = 'openwebtext_ccnews_stories_books'
            else:
                raise ValueError('Dataset not supported')

            # Get stock RoBERTa with MLM outputs
            kwargs = {
                'dataset_name': dataset_prefix + dataset_suffix,
                'pretrained': True,
                'ctx': ctxs,
                'use_pooler': False,
                'use_decoder': False,
                'use_classifier': False
            }
            if finetune or regression:
                kwargs['use_pooler'] = True
            else:
                kwargs['use_decoder'] = True
            # Override GluonNLP's default location?
            if root is not None:
                kwargs['root'] = str(root)
            model, vocab = get_model(model_fullname, **kwargs)

            # Freeze initial layers if needed
            for i in range(freeze):
                model.encoder.transformer_cells[i].collect_params().setattr('grad_req', 'null')

            # Wrapper if appropriate
            if regression:
                ValueError("Not yet tested")
                # NOTE THIS:
                model = BERTRegression(model, dropout=0.1)
                model.regression.initialize(init=mx.init.Normal(1.0), ctx=ctxs)

            # Get tokenizer
            tokenizer = nlp.data.GPT2BPETokenizer()

            # TODO: Have the scorers condition on what the vocab and tokenizer class are
            vocab.cls_token = vocab.bos_token
            vocab.sep_token = vocab.eos_token
            tokenizer.convert_tokens_to_ids = vocab.to_indices

        elif model_name == 'gpt2':

            assert cased
            assert not finetune
            assert not regression
            assert freeze == 0

            if size == '117m':
                model_fullname = 'gpt2_117m'
            elif size == '345m':
                model_fullname = 'gpt2_345m'

            # Get stock GPT-2
            kwargs = {
                'dataset_name': 'openai_webtext',
                'pretrained': True,
                'ctx': ctxs,
            }
            # Override GluonNLP's default location?
            if root is not None:
                kwargs['root'] = str(root)

            model, vocab = get_model(model_fullname, **kwargs)

            # Get tokenizer
            tokenizer = nlp.data.GPT2BPETokenizer()
            # To fit the assumptions of score block
            tokenizer.vocab = vocab
            vocab.cls_token = vocab.eos_token
            vocab.sep_token = vocab.eos_token
            tokenizer.convert_tokens_to_ids = vocab.to_indices

        if params_file is not None:
            model.load_parameters(str(params_file),
                ctx=ctxs, allow_missing=True, ignore_extra=True, cast_dtype=True)

    return model, vocab, tokenizer
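
A minimal usage sketch (not part of the original file), assuming the package is importable as mlm, that 'bert-base-en-uncased' appears in SUPPORTED, and that the Hugging Face name 'bert-base-uncased' does not:

import mxnet as mx
from mlm.models import get_pretrained

ctxs = [mx.cpu()]

# GluonNLP/MXNet path: the name follows the model-size-lang-cased/uncased(-dataset) scheme.
model, vocab, tokenizer = get_pretrained(ctxs, name='bert-base-en-uncased')

# transformers/PyTorch path: names not in SUPPORTED fall through to Hugging Face;
# the returned vocab is None in that case.
pt_model, pt_vocab, pt_tokenizer = get_pretrained(ctxs, name='bert-base-uncased')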