in src/model.py [0:0]
def __init__(self, dimension_params, metadata_constructor_params, layer_params):
    """Build the embedding, LSTM and output-projection layers.

    Args:
        dimension_params: Mapping read for the keys ``"emb_dim"``,
            ``"hidden_dim"`` and ``"vocab_size"``. It is also forwarded to
            ``MetadataConstructor`` when metadata is enabled.
        metadata_constructor_params: Mapping read for the keys
            ``"md_dims"`` and ``"md_group_sizes"``. Metadata support
            (``self.use_md``) is enabled only when both values are truthy;
            in that case the whole mapping is forwarded to
            ``MetadataConstructor``.
        layer_params: Mapping read for the keys ``"n_layers"`` (number of
            stacked LSTM layers) and ``"use_weight_tying"`` (whether the
            output projection shares its weight with the embedding table).

    Raises:
        KeyError: if any of the keys listed above is missing.
    """
    super().__init__()

    # Model dimensions.
    self.emb_dim = dimension_params["emb_dim"]
    self.hidden_dim = dimension_params["hidden_dim"]
    self.vocab_size = dimension_params["vocab_size"]

    # Metadata is used only when both the dims and the group sizes are
    # truthy (e.g. non-empty and not None).
    self.md_dims = metadata_constructor_params["md_dims"]
    self.md_group_sizes = metadata_constructor_params["md_group_sizes"]
    self.use_md = bool(self.md_dims and self.md_group_sizes)

    self.n_layers = layer_params["n_layers"]
    self.use_weight_tying = layer_params["use_weight_tying"]

    self.embeddings = nn.Embedding(self.vocab_size, self.emb_dim)
    self.lstm = nn.LSTM(
        input_size=self.emb_dim,
        hidden_size=self.hidden_dim,
        num_layers=self.n_layers,
        batch_first=True,
    )

    if self.use_weight_tying:
        # nn.Linear(emb_dim, vocab_size).weight has shape
        # (vocab_size, emb_dim), the same as the embedding table, so the two
        # can share a single Parameter. Because the tied projection consumes
        # emb_dim-sized inputs, an extra hidden_dim -> emb_dim projection is
        # created; presumably forward() applies it to the LSTM output before
        # vocab_projection — confirm.
        self.vocab_projection = nn.Linear(self.emb_dim, self.vocab_size)
        self.embedding_projection = nn.Linear(self.hidden_dim, self.emb_dim)
        self.vocab_projection.weight = self.embeddings.weight
    else:
        self.vocab_projection = nn.Linear(self.hidden_dim, self.vocab_size)

    # NOTE: self.metadata_constructor is only defined when use_md is True.
    if self.use_md:
        self.metadata_constructor = MetadataConstructor(
            metadata_constructor_params, dimension_params
        )