def torchscriptify()

in pytext/models/doc_model.py [0:0]


    def torchscriptify(self, tensorizers, traced_model):  # noqa
        """Wrap ``traced_model`` in a TorchScript module that accepts raw inputs.

        The returned module performs, inside TorchScript, the preprocessing
        that happens before the traced model is called: optional tokenization
        of raw texts, truncation, vocabulary lookup and padding. It then runs
        the traced model and feeds its logits to the scripted output layer.

        Args:
            tensorizers: Mapping of tensorizer name to tensorizer. The
                ``"tokens"`` entry is required (its ``vocab``, ``max_seq_len``
                and ``tokenizer`` are read). If a ``"dense"`` entry is present,
                its ``normalizer`` is used and the returned module additionally
                requires a ``dense_feat`` input.
            traced_model: The already-traced model; called with word-id and
                sequence-length tensors (plus a dense-feature tensor in the
                dense variant).

        Returns:
            A ``ModelWithDenseFeat`` instance when ``"dense"`` is a key of
            ``tensorizers``, otherwise a ``Model`` instance. Both are
            ``jit.ScriptModule`` subclasses defined locally in this method.
        """
        output_layer = self.output_layer.torchscript_predictions()

        input_vocab = tensorizers["tokens"].vocab
        # A falsy max_seq_len (None or 0) is replaced by -1.
        # NOTE(review): presumably -1 means "no truncation" to
        # truncate_tokens -- confirm against that helper.
        max_seq_len = tensorizers["tokens"].max_seq_len or -1
        scripted_tokenizer: Optional[jit.ScriptModule] = None
        # Best effort: a tokenizer whose torchscriptify() raises
        # NotImplementedError simply leaves scripted_tokenizer as None, in
        # which case the exported module only works with pre-tokenized input.
        try:
            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
        except NotImplementedError:
            pass
        # A DoNothingTokenizer is treated the same as having no tokenizer.
        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
            scripted_tokenizer = None

        """
        The input tensor packing memory is allocated/cached for different shapes,
        and max sequence length will help to reduce the number of different tensor
        shapes. We noticed that the TorchScript model could use 25G for offline
        inference on CPU without using max_seq_len.
        """

        # Variant used when there is no "dense" tensorizer: tokens only.
        # The classes close over input_vocab, traced_model, output_layer,
        # max_seq_len and scripted_tokenizer from the enclosing method.
        class Model(jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.vocab = ScriptVocabulary(
                    input_vocab,
                    input_vocab.get_unk_index(),
                    input_vocab.get_pad_index(),
                )
                self.model = traced_model
                self.output_layer = output_layer
                # Plain Python ints are wrapped in jit.Attribute with an
                # explicit type so they are stored as typed module attributes.
                self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                self.max_seq_len = jit.Attribute(max_seq_len, int)
                # May be None (no scriptable tokenizer available).
                self.tokenizer = scripted_tokenizer

            # Inputs: pass either `texts` (raw strings, tokenized here when a
            # tokenizer is available) or `tokens` (already tokenized), never
            # both. `multi_texts` and `languages` are accepted but not used in
            # this body. Returns whatever the scripted output layer produces
            # from the model's logits.
            @jit.script_method
            def forward(
                self,
                texts: Optional[List[str]] = None,
                multi_texts: Optional[List[List[str]]] = None,
                tokens: Optional[List[List[str]]] = None,
                languages: Optional[List[str]] = None,
            ):
                # PyTorch breaks with 2 'not None' checks right now.
                # (Hence the nested ifs below instead of a single combined
                # condition.)
                if texts is not None:
                    if tokens is not None:
                        raise RuntimeError("Can't set both tokens and texts")
                    if self.tokenizer is not None:
                        # Keep only element 0 of each item returned by the
                        # tokenizer.
                        # NOTE(review): presumably element 0 is the token
                        # string -- confirm against the scripted tokenizer.
                        tokens = [
                            [t[0] for t in self.tokenizer.tokenize(text)]
                            for text in texts
                        ]

                # Also reached when `texts` was given but no tokenizer is
                # available: tokens is still None in that case.
                if tokens is None:
                    raise RuntimeError("tokens is required")

                # Pipeline: truncate -> sequence lengths -> vocab lookup ->
                # pad with pad_idx -> tensors -> traced model -> output layer.
                tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
                seq_lens = make_sequence_lengths(tokens)
                word_ids = self.vocab.lookup_indices_2d(tokens)
                word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
                return self.output_layer(logits)

        # Variant used when a "dense" tensorizer is present: same token
        # pipeline as Model, plus a normalized dense-feature input passed to
        # the traced model as a third tensor.
        class ModelWithDenseFeat(jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.vocab = ScriptVocabulary(
                    input_vocab,
                    input_vocab.get_unk_index(),
                    input_vocab.get_pad_index(),
                )
                # Normalizer taken from the "dense" tensorizer; applied to
                # dense_feat in forward().
                self.normalizer = tensorizers["dense"].normalizer
                self.model = traced_model
                self.output_layer = output_layer
                self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                self.max_seq_len = jit.Attribute(max_seq_len, int)
                # May be None (no scriptable tokenizer available).
                self.tokenizer = scripted_tokenizer

            # Same contract as Model.forward, with one addition: `dense_feat`
            # must be provided (a RuntimeError is raised when it is None).
            # `multi_texts` and `languages` are accepted but not used in this
            # body.
            @jit.script_method
            def forward(
                self,
                texts: Optional[List[str]] = None,
                multi_texts: Optional[List[List[str]]] = None,
                tokens: Optional[List[List[str]]] = None,
                languages: Optional[List[str]] = None,
                dense_feat: Optional[List[List[float]]] = None,
            ):
                # PyTorch breaks with 2 'not None' checks right now.
                # (Hence the nested ifs below instead of a single combined
                # condition.)
                if texts is not None:
                    if tokens is not None:
                        raise RuntimeError("Can't set both tokens and texts")
                    if self.tokenizer is not None:
                        # Keep only element 0 of each item returned by the
                        # tokenizer.
                        # NOTE(review): presumably element 0 is the token
                        # string -- confirm against the scripted tokenizer.
                        tokens = [
                            [t[0] for t in self.tokenizer.tokenize(text)]
                            for text in texts
                        ]

                # Also reached when `texts` was given but no tokenizer is
                # available: tokens is still None in that case.
                if tokens is None:
                    raise RuntimeError("tokens is required")
                if dense_feat is None:
                    raise RuntimeError("dense_feat is required")

                # Pipeline: truncate -> sequence lengths -> vocab lookup ->
                # pad with pad_idx; dense features are normalized separately.
                tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
                seq_lens = make_sequence_lengths(tokens)
                word_ids = self.vocab.lookup_indices_2d(tokens)
                word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                dense_feat = self.normalizer.normalize(dense_feat)
                logits = self.model(
                    torch.tensor(word_ids),
                    torch.tensor(seq_lens),
                    # Dense features are explicitly converted to a float
                    # tensor.
                    torch.tensor(dense_feat, dtype=torch.float),
                )
                return self.output_layer(logits)

        # The presence of a "dense" tensorizer selects the dense variant.
        return ModelWithDenseFeat() if "dense" in tensorizers else Model()