in empchat/transformer_local.py [0:0]
def __init__(self, opt, dictionary):
    """Build the context and candidate Transformer encoders.

    A single ``nn.Embedding`` table is created here and the same instance is
    passed to both the context and the candidate ``TransformerModel``. Both
    models are constructed with identical hyperparameters.

    Args:
        opt: Options namespace. Reads ``embeddings_size``,
            ``learn_embeddings``, ``transformer_dropout``,
            ``transformer_n_heads``, ``n_layers`` and ``transformer_dim``.
        dictionary: Vocabulary supporting ``len()`` and ``[PAD_TOKEN]``
            lookup. Its length sets the number of embedding rows, and the
            ``PAD_TOKEN`` entry is used as the embedding ``padding_idx``.
    """
    super(TransformerAdapter, self).__init__()
    self.opt = opt
    self.pad_idx = dictionary[PAD_TOKEN]
    self.embeddings = nn.Embedding(
        len(dictionary), opt.embeddings_size, padding_idx=self.pad_idx
    )
    if not opt.learn_embeddings:
        # Exclude the embedding table from gradient updates.
        self.embeddings.weight.requires_grad = False
    nn.init.normal_(self.embeddings.weight, mean=0, std=0.05)
    # normal_ above overwrites the all-zero row that nn.Embedding creates for
    # padding_idx. That row never receives gradient, so without this reset
    # the pad token would keep a fixed random vector. Restore it to zero.
    if self.pad_idx is not None:
        nn.init.zeros_(self.embeddings.weight[self.pad_idx])
    # An unset/falsy dropout option (e.g. None) means no dropout.
    dropout = opt.transformer_dropout or 0
    # Context encoder.
    self.ctx_transformer = TransformerModel(
        opt.transformer_n_heads,
        opt.n_layers,
        opt.transformer_dim,
        len(dictionary),
        embedding=self.embeddings,
        dropout=dropout,
    )
    # Candidate encoder: same hyperparameters, same embedding instance.
    self.cand_transformer = TransformerModel(
        opt.transformer_n_heads,
        opt.n_layers,
        opt.transformer_dim,
        len(dictionary),
        embedding=self.embeddings,
        dropout=dropout,
    )
    # Point the adapter's handle at the embedding module the context encoder
    # actually holds. NOTE(review): presumably this is the same object passed
    # in above; confirm against TransformerModel before removing.
    self.embeddings = self.ctx_transformer.embeddings