in ttw/models/language.py [0:0]
def __init__(self, act_emb_sz, act_hid_sz, num_actions, obs_emb_sz, obs_hid_sz, num_observations,
             decoder_emb_sz, decoder_hid_sz, num_words, start_token=1, end_token=2):
    """Build the action/observation encoders and the GRU word decoder.

    Args:
        act_emb_sz: Embedding size passed to the action ``GRUEncoder``.
        act_hid_sz: Hidden size of the action ``GRUEncoder``; together with
            ``obs_hid_sz`` it sets the input width of ``context_linear``.
        num_actions: Action vocabulary size passed to the action ``GRUEncoder``.
        obs_emb_sz: Embedding size passed to the observation ``GRUEncoder``.
        obs_hid_sz: Hidden size of the observation ``GRUEncoder``.
        num_observations: Observation vocabulary size passed to the
            observation ``GRUEncoder``.
        decoder_emb_sz: Word-embedding size, also the output size of
            ``context_linear``.
        decoder_hid_sz: Hidden size of the decoder GRU.
        num_words: Output vocabulary size (embedding rows and logit width).
        start_token: Token id stored as ``self.start_token`` (default 1).
            Presumably the start-of-sequence index; confirm in the vocabulary.
        end_token: Token id stored as ``self.end_token`` (default 2).
            Presumably the end-of-sequence index; confirm in the vocabulary.
    """
    super(TouristLanguage, self).__init__()
    self.act_emb_sz = act_emb_sz
    self.act_hid_sz = act_hid_sz
    self.num_actions = num_actions
    self.obs_emb_sz = obs_emb_sz
    self.obs_hid_sz = obs_hid_sz
    self.num_observations = num_observations
    self.decoder_emb_sz = decoder_emb_sz
    self.decoder_hid_sz = decoder_hid_sz
    self.num_words = num_words

    # Encoders for the action and observation inputs. The exact meaning of
    # cbow=True is defined in GRUEncoder (not visible here).
    self.act_encoder = GRUEncoder(act_emb_sz, act_hid_sz, num_actions)
    self.obs_encoder = GRUEncoder(obs_emb_sz, obs_hid_sz, num_observations, cbow=True)

    # Word embeddings for the decoder, initialised from N(0, 0.1).
    self.emb_fn = nn.Embedding(num_words, decoder_emb_sz)
    nn.init.normal_(self.emb_fn.weight, mean=0.0, std=0.1)

    # The decoder input is 2 * decoder_emb_sz wide. Presumably this is the
    # word embedding concatenated with the projected context; verify in forward().
    self.decoder = nn.GRU(2 * decoder_emb_sz, decoder_hid_sz, batch_first=True)
    # Projects the (presumably concatenated) action + observation encodings
    # down to the word-embedding size.
    self.context_linear = nn.Linear(act_hid_sz + obs_hid_sz, decoder_emb_sz)
    # Maps decoder hidden states to per-word logits.
    self.out_linear = nn.Linear(decoder_hid_sz, num_words)

    # Unreduced cross-entropy returns one loss per element. `reduction='none'`
    # replaces the deprecated `reduce=False`, with identical behavior.
    self.loss = nn.CrossEntropyLoss(reduction='none')
    self.start_token = start_token
    self.end_token = end_token