def torchscriptify()

in pytext/models/doc_model.py [0:0]


    def torchscriptify(self, tensorizers, traced_model):
        """Wrap ``traced_model`` in a TorchScript module that accepts raw tokens.

        The returned module performs the input preparation itself: it
        truncates the tokens, looks them up in the vocabulary, pads the ids,
        builds the byte inputs, calls ``traced_model`` and finally feeds the
        logits to the scripted prediction layer obtained from
        ``self.output_layer.torchscript_predictions()``.

        Args:
            tensorizers: Container indexed by name. It must provide the
                ``"tokens"`` and ``"token_bytes"`` entries; when a ``"dense"``
                entry is present the dense-feature variant is built and its
                ``normalizer`` is used.
            traced_model: The model that is called with the prepared tensors.

        Returns:
            A ``ModelWithDenseFeat`` instance if ``"dense"`` is in
            ``tensorizers``, otherwise a ``Model`` instance.
        """
        predictions_layer = self.output_layer.torchscript_predictions()

        token_tensorizer = tensorizers["tokens"]
        seq_len_limit = token_tensorizer.max_seq_len
        if not seq_len_limit:
            # A falsy limit (e.g. None) is replaced by -1 before being handed
            # to truncate_tokens -- presumably "no truncation"; confirm
            # against that helper.
            seq_len_limit = -1

        byte_tensorizer = tensorizers["token_bytes"]
        byte_len_limit = byte_tensorizer.max_byte_len
        non_padding_byte_offset = byte_tensorizer.offset_for_non_padding
        vocab = token_tensorizer.vocab

        class Model(jit.ScriptModule):
            """Scripted wrapper whose ``forward`` consumes tokens only."""

            def __init__(self):
                super().__init__()
                self.vocab = ScriptVocabulary(
                    vocab, vocab.get_unk_index(), vocab.get_pad_index()
                )
                self.max_seq_len = jit.Attribute(seq_len_limit, int)
                self.max_byte_len = jit.Attribute(byte_len_limit, int)
                self.byte_offset_for_non_padding = jit.Attribute(
                    non_padding_byte_offset, int
                )
                self.pad_idx = jit.Attribute(vocab.get_pad_index(), int)
                self.model = traced_model
                self.output_layer = predictions_layer

            @jit.script_method
            def forward(
                self,
                texts: Optional[List[str]] = None,
                multi_texts: Optional[List[List[str]]] = None,
                tokens: Optional[List[List[str]]] = None,
                languages: Optional[List[str]] = None,
            ):
                # Only ``tokens`` is read here; ``texts``, ``multi_texts`` and
                # ``languages`` are accepted but not used in this method.
                if tokens is None:
                    raise RuntimeError("tokens is required")

                trimmed = truncate_tokens(
                    tokens, self.max_seq_len, self.vocab.pad_token
                )
                lengths = make_sequence_lengths(trimmed)
                padded_ids = pad_2d(
                    self.vocab.lookup_indices_2d(trimmed), lengths, self.pad_idx
                )
                # The second value returned by make_byte_inputs is not needed.
                byte_inputs, _ = make_byte_inputs(
                    trimmed, self.max_byte_len, self.byte_offset_for_non_padding
                )

                ids_tensor = torch.tensor(padded_ids)
                lengths_tensor = torch.tensor(lengths)
                scores = self.model(ids_tensor, byte_inputs, lengths_tensor)
                return self.output_layer(scores)

        class ModelWithDenseFeat(jit.ScriptModule):
            """Scripted wrapper whose ``forward`` also consumes dense features."""

            def __init__(self):
                super().__init__()
                self.vocab = ScriptVocabulary(
                    vocab, vocab.get_unk_index(), vocab.get_pad_index()
                )
                self.normalizer = tensorizers["dense"].normalizer
                self.max_seq_len = jit.Attribute(seq_len_limit, int)
                self.max_byte_len = jit.Attribute(byte_len_limit, int)
                self.byte_offset_for_non_padding = jit.Attribute(
                    non_padding_byte_offset, int
                )
                self.pad_idx = jit.Attribute(vocab.get_pad_index(), int)
                self.model = traced_model
                self.output_layer = predictions_layer

            @jit.script_method
            def forward(
                self,
                texts: Optional[List[str]] = None,
                multi_texts: Optional[List[List[str]]] = None,
                tokens: Optional[List[List[str]]] = None,
                languages: Optional[List[str]] = None,
                dense_feat: Optional[List[List[float]]] = None,
            ):
                # ``tokens`` and ``dense_feat`` are both mandatory; ``texts``,
                # ``multi_texts`` and ``languages`` are accepted but not used
                # in this method.
                if tokens is None:
                    raise RuntimeError("tokens is required")
                if dense_feat is None:
                    raise RuntimeError("dense_feat is required")

                trimmed = truncate_tokens(
                    tokens, self.max_seq_len, self.vocab.pad_token
                )
                lengths = make_sequence_lengths(trimmed)
                padded_ids = pad_2d(
                    self.vocab.lookup_indices_2d(trimmed), lengths, self.pad_idx
                )
                # The second value returned by make_byte_inputs is not needed.
                byte_inputs, _ = make_byte_inputs(
                    trimmed, self.max_byte_len, self.byte_offset_for_non_padding
                )
                normalized_dense = self.normalizer.normalize(dense_feat)

                ids_tensor = torch.tensor(padded_ids)
                lengths_tensor = torch.tensor(lengths)
                dense_tensor = torch.tensor(normalized_dense, dtype=torch.float)
                scores = self.model(
                    ids_tensor, byte_inputs, lengths_tensor, dense_tensor
                )
                return self.output_layer(scores)

        # The presence of a "dense" tensorizer selects the dense-feature variant.
        if "dense" in tensorizers:
            return ModelWithDenseFeat()
        return Model()