# Excerpt from pytext/models/doc_model.py — DocModel.torchscriptify
def torchscriptify(self, tensorizers, traced_model):
    """Export this model as a self-contained TorchScript module.

    Wraps the already-traced forward computation (*traced_model*) in a
    ``jit.ScriptModule`` that also performs the preprocessing steps in
    TorchScript: token truncation, vocab lookup, 2D padding, and byte-input
    construction — so the exported artifact can serve raw token lists.

    Args:
        tensorizers: dict of configured tensorizers. "tokens" and
            "token_bytes" are required; an optional "dense" entry selects
            the dense-feature variant of the exported module.
        traced_model: traced core model called with word ids, token bytes,
            and sequence lengths (plus a dense-feature tensor in the
            dense variant).

    Returns:
        ``ModelWithDenseFeat()`` if a "dense" tensorizer is configured,
        otherwise ``Model()``.
    """
    output_layer = self.output_layer.torchscript_predictions()
    # -1 is the sentinel for "no truncation" when max_seq_len is unset/0.
    max_seq_len = tensorizers["tokens"].max_seq_len or -1
    max_byte_len = tensorizers["token_bytes"].max_byte_len
    byte_offset_for_non_padding = tensorizers["token_bytes"].offset_for_non_padding
    input_vocab = tensorizers["tokens"].vocab

    class Model(jit.ScriptModule):
        # TorchScript-exportable wrapper for the token-only input path.
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(
                input_vocab,
                input_vocab.get_unk_index(),
                input_vocab.get_pad_index(),
            )
            # jit.Attribute pins each captured Python value to an explicit
            # TorchScript type (int) on the scripted module.
            self.max_seq_len = jit.Attribute(max_seq_len, int)
            self.max_byte_len = jit.Attribute(max_byte_len, int)
            self.byte_offset_for_non_padding = jit.Attribute(
                byte_offset_for_non_padding, int
            )
            self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
            self.model = traced_model
            self.output_layer = output_layer

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
        ):
            # Only `tokens` is consumed; `texts`, `multi_texts`, and
            # `languages` are never read here — presumably kept so all
            # exported predictors share one forward() signature
            # (NOTE(review): confirm against the serving callers).
            if tokens is None:
                raise RuntimeError("tokens is required")
            # Preprocess: truncate -> lengths -> vocab lookup -> pad to a
            # rectangular batch -> build byte-level inputs.
            tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
            seq_lens = make_sequence_lengths(tokens)
            word_ids = self.vocab.lookup_indices_2d(tokens)
            word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
            token_bytes, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            logits = self.model(
                torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens)
            )
            return self.output_layer(logits)

    class ModelWithDenseFeat(jit.ScriptModule):
        # Same as Model, plus a required dense-feature input that is
        # normalized and passed to the traced model as a float tensor.
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(
                input_vocab,
                input_vocab.get_unk_index(),
                input_vocab.get_pad_index(),
            )
            # Normalizer captured from the "dense" tensorizer; applied to
            # raw dense features in forward().
            self.normalizer = tensorizers["dense"].normalizer
            self.max_seq_len = jit.Attribute(max_seq_len, int)
            self.max_byte_len = jit.Attribute(max_byte_len, int)
            self.byte_offset_for_non_padding = jit.Attribute(
                byte_offset_for_non_padding, int
            )
            self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
            self.model = traced_model
            self.output_layer = output_layer

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
            dense_feat: Optional[List[List[float]]] = None,
        ):
            # Both `tokens` and `dense_feat` are required at runtime even
            # though the signature defaults them to None (the Optional
            # defaults keep the forward() interface uniform).
            if tokens is None:
                raise RuntimeError("tokens is required")
            if dense_feat is None:
                raise RuntimeError("dense_feat is required")
            tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
            seq_lens = make_sequence_lengths(tokens)
            word_ids = self.vocab.lookup_indices_2d(tokens)
            word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
            token_bytes, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            dense_feat = self.normalizer.normalize(dense_feat)
            logits = self.model(
                torch.tensor(word_ids),
                token_bytes,
                torch.tensor(seq_lens),
                torch.tensor(dense_feat, dtype=torch.float),
            )
            return self.output_layer(logits)

    # The presence of a "dense" tensorizer decides which variant to export.
    return ModelWithDenseFeat() if "dense" in tensorizers else Model()