in pytorch_translate/char_source_model.py
@classmethod
def build_model(cls, args, task):
    """Build a new model instance."""
    src_dict, dst_dict = task.source_dictionary, task.target_dictionary
    base_architecture(args)
    # This model only supports the LSTM-based sequence encoder/decoder.
    assert args.sequence_lstm, "CharRNNModel only supports sequence_lstm"
    assert args.cell_type == "lstm", "CharRNNModel only supports cell_type lstm"
    assert hasattr(args, "char_source_dict_size"), (
        "args.char_source_dict_size required. "
        "should be set by load_binarized_dataset()"
    )

    # Prefer the character-CNN encoder whenever CNN params are configured.
    if hasattr(args, "char_cnn_params"):
        args.embed_bytes = getattr(args, "embed_bytes", False)

        # If we embed bytes, the number of indices is fixed and does not
        # depend on the dictionary.
        if args.embed_bytes:
            num_chars = vocab_constants.NUM_BYTE_INDICES + len(TAGS) + 1
        else:
            num_chars = args.char_source_dict_size

        # If use_pretrained_weights is true, verify that the model params
        # are set correctly.
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            verify_pretrain_params(args)

        encoder = CharCNNEncoder(
            src_dict,
            num_chars=num_chars,
            unk_only_char_encoding=args.unk_only_char_encoding,
            embed_dim=args.char_embed_dim,
            token_embed_dim=args.encoder_embed_dim,
            freeze_embed=args.encoder_freeze_embed,
            normalize_embed=args.encoder_normalize_embed,
            char_cnn_params=args.char_cnn_params,
            char_cnn_nonlinear_fn=args.char_cnn_nonlinear_fn,
            char_cnn_num_highway_layers=args.char_cnn_num_highway_layers,
            char_cnn_output_dim=getattr(args, "char_cnn_output_dim", -1),
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
            use_pretrained_weights=getattr(args, "use_pretrained_weights", False),
            finetune_pretrained_weights=getattr(
                args, "finetune_pretrained_weights", False
            ),
            weights_file=getattr(args, "pretrained_weights_file", ""),
        )
    else:
        # Without char_cnn_params, fall back to the character-RNN encoder.
        assert (
            args.unk_only_char_encoding is False
        ), "unk_only_char_encoding should be False when using CharRNNEncoder"
        encoder = CharRNNEncoder(
            src_dict,
            num_chars=args.char_source_dict_size,
            char_embed_dim=args.char_embed_dim,
            token_embed_dim=args.encoder_embed_dim,
            normalize_embed=args.encoder_normalize_embed,
            char_rnn_units=args.char_rnn_units,
            char_rnn_layers=args.char_rnn_layers,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )

    # The decoder operates on whole target tokens; its embedding table may be
    # initialized from a pretrained embedding file.
    decoder_embed_tokens = Embedding(
        num_embeddings=len(dst_dict),
        embedding_dim=args.decoder_embed_dim,
        padding_idx=dst_dict.pad(),
        freeze_embed=args.decoder_freeze_embed,
    )
    utils.load_embedding(
        embedding=decoder_embed_tokens,
        dictionary=dst_dict,
        pretrained_embed=args.decoder_pretrained_embed,
    )
    decoder = rnn.RNNDecoder(
        src_dict=src_dict,
        dst_dict=dst_dict,
        embed_tokens=decoder_embed_tokens,
        vocab_reduction_params=args.vocab_reduction_params,
        encoder_hidden_dim=args.encoder_hidden_dim,
        embed_dim=args.decoder_embed_dim,
        out_embed_dim=args.decoder_out_embed_dim,
        cell_type=args.cell_type,
        num_layers=args.decoder_layers,
        hidden_dim=args.decoder_hidden_dim,
        attention_type=args.attention_type,
        dropout_in=args.decoder_dropout_in,
        dropout_out=args.decoder_dropout_out,
        residual_level=args.residual_level,
        averaging_encoder=args.averaging_encoder,
    )
    return cls(task, encoder, decoder)
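
For context, a minimal sketch of how this factory is typically reached, assuming the enclosing class is registered as a fairseq model and trained through the standard fairseq entry points; the setup flow below is illustrative, not taken from this file:

# Hypothetical usage sketch. Assumes the standard fairseq training flow;
# the relevant flags (e.g. --cell-type, char CNN/RNN params) would be
# supplied on the command line and validated by the asserts above.
from fairseq import options, tasks

parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)

# The task owns source_dictionary/target_dictionary, which build_model() reads.
task = tasks.setup_task(args)

# Dispatches to the registered model class, i.e. cls.build_model(args, task),
# which picks CharCNNEncoder vs. CharRNNEncoder based on args.char_cnn_params.
model = task.build_model(args)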