in ttw/models/language.py [0:0]
def __init__(self, inp_emb_sz, hidden_sz, num_tokens, apply_masc=True, T=1):
    """Build the sub-modules and learnable parameters of the guide language model.

    Args:
        inp_emb_sz: Dimensionality of the token embeddings fed to the LSTM encoder.
        hidden_sz: Hidden size shared by the control steps, the linear heads,
            the write gates and the MASC/NoMASC module. The bidirectional LSTM
            uses ``hidden_sz // 2`` units per direction.
        num_tokens: Number of rows in the token embedding table; index 0 is
            the padding index.
        apply_masc: When true, also build the action-control embedding, the
            action control step and the action head, and use ``MASC``;
            otherwise use ``NoMASC``.
        T: Sets the output size of ``T_prediction_fn`` (``T + 1`` classes) and
            the number of write-gate parameters (``T + 1`` of each kind).

    Note:
        Parameters are randomly initialised as they are created, so the order
        of construction below determines which random values each one receives
        under a fixed seed, and the order in which they are registered.
        Do not reorder the module/parameter constructions.
    """
    super(GuideLanguage, self).__init__()

    def gaussian_param(*shape):
        # Uninitialised legacy FloatTensor, filled in place from N(0, 0.1^2).
        return nn.Parameter(torch.FloatTensor(*shape).normal_(0.0, 0.1))

    # Plain hyper-parameters (no randomness involved, so their order is free).
    self.inp_emb_sz = inp_emb_sz
    self.hidden_sz = hidden_sz
    self.num_tokens = num_tokens
    self.apply_masc = apply_masc
    self.T = T

    # Token embedding followed by a bidirectional LSTM whose two directions of
    # hidden_sz // 2 units each are concatenated.
    # NOTE(review): an odd hidden_sz yields an LSTM output one unit narrower
    # than hidden_sz — confirm callers always pass an even value.
    self.embed_fn = nn.Embedding(num_tokens, inp_emb_sz, padding_idx=0)
    self.encoder_fn = nn.LSTM(
        inp_emb_sz,
        hidden_sz // 2,
        batch_first=True,
        bidirectional=True,
    )
    # NOTE(review): 11 is presumably the number of landmark categories — confirm.
    self.cbow_fn = CBoW(11, hidden_sz)

    # Linear head with T + 1 outputs (presumably one class per step count
    # 0..T — confirm against the forward pass).
    self.T_prediction_fn = nn.Linear(hidden_sz, T + 1)

    self.feat_control_emb = gaussian_param(hidden_sz)
    self.feat_control_step_fn = ControlStep(hidden_sz)

    if apply_masc:
        self.act_control_emb = gaussian_param(hidden_sz)
        self.act_control_step_fn = ControlStep(hidden_sz)
        # NOTE(review): 9 outputs presumably map to a 3x3 action kernel — confirm.
        self.action_linear_fn = nn.Linear(hidden_sz, 9)

    # T + 1 gates of each kind; creation alternates landmark/obs per step,
    # which fixes the random draws each gate receives.
    self.landmark_write_gate = nn.ParameterList()
    self.obs_write_gate = nn.ParameterList()
    for _ in range(T + 1):
        self.landmark_write_gate.append(gaussian_param(1, hidden_sz, 1, 1))
        self.obs_write_gate.append(gaussian_param(1, hidden_sz))

    masc_cls = MASC if apply_masc else NoMASC
    self.masc_fn = masc_cls(self.hidden_sz)

    self.loss = nn.CrossEntropyLoss()