in mm_action_prediction/models/decoder.py
def __init__(self, params):
    super(GenerativeDecoder, self).__init__()
    self.params = params
    # Dialog context encoders.
    self.DIALOG_CONTEXT_ENCODERS = ("hierarchical_recurrent", "memory_network")
    # Word embedding.
    self.word_embed_net = nn.Embedding(
        params["vocab_size"], params["word_embed_size"]
    )
    # Text encoder.
    if params["text_encoder"] == "transformer":
        if params["encoder"] != "pretrained_transformer":
            decoder_layer = nn.TransformerDecoderLayer(
                params["word_embed_size"],
                params["num_heads_transformer"],
                params["hidden_size_transformer"],
            )
            self.decoder_unit = nn.TransformerDecoder(
                decoder_layer, params["num_layers_transformer"]
            )
            self.pos_encoder = models.PositionalEncoding(
                params["word_embed_size"]
            )
        else:
            self.decoder_unit = None
        # Placeholder for the no-peek (causal) attention mask used during decoding.
        self.no_peek_mask = None
    elif params["text_encoder"] == "lstm":
        input_size = params["word_embed_size"]
        if params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
            input_size += params["hidden_size"]
        if params["use_bahdanau_attention"]:
            input_size += params["hidden_size"]
        if params["encoder"] == "tf_idf":
            # If encoder is tf_idf, use a simple decoder over word embeddings only.
            input_size = params["word_embed_size"]
        self.decoder_unit = nn.LSTM(
            input_size,
            params["hidden_size"],
            params["num_layers"],
            batch_first=True,
        )
    else:
        raise NotImplementedError("Text encoder must be transformer or LSTM!")
    input_size = params["hidden_size"]
    if params["use_bahdanau_attention"] and params["text_encoder"] == "lstm":
        output_size = params["hidden_size"]
        # if self.params['use_action_output']:
        #     input_size += hidden_size
        self.attention_net = nn.Linear(input_size, output_size)
        # The attended context vector is concatenated with the decoder state.
        input_size += params["hidden_size"]
    # If action outputs are to be used.
    if params["use_action_output"]:
        self.action_fusion_net = nn.Linear(
            3 * params["hidden_size"], params["hidden_size"]
        )
    # Reset the input_size if tf_idf.
    if params["encoder"] == "tf_idf":
        input_size = params["hidden_size"]
    # Project decoder outputs onto the vocabulary.
    self.inv_word_net = nn.Linear(input_size, params["vocab_size"])
    # Per-token cross-entropy loss (no reduction).
    self.criterion = nn.CrossEntropyLoss(reduction="none")
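
# A minimal usage sketch (not part of the repository) showing how this
# constructor might be exercised for the LSTM variant. The key names come
# from the __init__ above; the values and the import path are illustrative
# assumptions only.
from mm_action_prediction.models.decoder import GenerativeDecoder  # assumed import path

params = {
    "vocab_size": 5000,
    "word_embed_size": 256,
    "hidden_size": 512,
    "num_layers": 2,
    "text_encoder": "lstm",
    "encoder": "hierarchical_recurrent",
    "use_bahdanau_attention": True,
    "use_action_output": False,
}
decoder = GenerativeDecoder(params)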