# torchtext/models/roberta/bundler.py
def get_model(self,
              *,
              head: Optional[Module] = None,
              load_weights: bool = True,
              freeze_encoder: bool = False,
              dl_kwargs: Optional[Dict[str, Any]] = None) -> RobertaModel:
    r"""get_model(*, head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, dl_kwargs=None) -> torchtext.models.RobertaModel

    Args:
        head (nn.Module): A module to be attached to the encoder to perform a specific task. If provided, it will replace the default head of the bundle. (Default: ``None``)
        load_weights (bool): Indicates whether or not to load weights if available. (Default: ``True``)
        freeze_encoder (bool): Indicates whether or not to freeze the encoder weights. (Default: ``False``)
        dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
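
    Example (illustrative sketch; assumes the ``XLMR_BASE_ENCODER`` bundle and a two-class
    ``RobertaClassificationHead``, both from :mod:`torchtext.models`):

        >>> import torch, torchtext
        >>> from torchtext.functional import to_tensor
        >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=768)
        >>> model = xlmr_base.get_model(head=classifier_head, freeze_encoder=True)
        >>> transform = xlmr_base.transform()
        >>> model_input = to_tensor(transform(["Hello world", "How are you!"]), padding_value=1)
        >>> output = model(model_input)  # logits of shape (batch_size, num_classes)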
"""

    if load_weights:
        assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"

    if freeze_encoder:
        if not load_weights or not self._path:
            logger.warning("The encoder is not loaded with pre-trained weights. Setting freeze_encoder to True will hinder the encoder from learning appropriate weights.")
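
    # Prefer a caller-supplied head; otherwise fall back to the bundle's default head (which may be None).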
    if head is not None:
        input_head = head
        if self._head is not None:
            logger.info("A custom head module was provided, discarding the default head module.")
    else:
        input_head = self._head
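
    # build_model assembles the encoder and head; when a checkpoint path is given it also loads
    # the pre-trained weights, forwarding dl_kwargs to torch.hub.load_state_dict_from_url.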
    return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf,
                                          head=input_head,
                                          freeze_encoder=freeze_encoder,
                                          checkpoint=self._path if load_weights else None,
                                          override_checkpoint_head=True,
                                          strict=True,
                                          dl_kwargs=dl_kwargs)