in agents/obj_nets.py [0:0]
def forward(self, t):
# t is B x T x N x F (batch, timesteps, objects, features)
# clone so the padding and masking below do not modify the caller's tensor
tensor = t.clone()
if self.aggregation == 'mlp':
# if agent.strip_padding is enabled and fewer than MAX_NUM_OBJECTS rows are present, pad back to the full object count
num_padding = phyre_simulator.MAX_NUM_OBJECTS - tensor.shape[2]
if num_padding > 0:
pad_zeros = torch.zeros(tensor.shape[0], tensor.shape[1],
num_padding,
tensor.shape[3]).to(tensor.device)
tensor = torch.cat((tensor, pad_zeros), dim=2)
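# padded object rows are all-zero feature vectors, so a zero row-sum marks them as padding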
row_sum = torch.sum(tensor, dim=-1)
is_pad = row_sum == 0
tensor_enc = self.encoder(tensor)
# tensor_enc is B x T x N x 128
tensor = self.pos_encoder(tensor_enc) * math.sqrt(self.n_inp)
# tensor is B x T x N x 128
tensor = torch.flatten(tensor, start_dim=1, end_dim=2)
mask = torch.flatten(is_pad, start_dim=1, end_dim=2)
# mask is B x (T x N)
# tensor is B x (T x N) x 128
tensor = tensor.permute(1, 0, 2)
# tensor is (T x N) x B x 128
if self.embed_tf:
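# contextualize the (T x N) tokens with self-attention; padded tokens are excluded via src_key_padding_mask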
tensor = self.transformer_encoder(tensor,
src_key_padding_mask=mask)
# tensor is (T x N) x B x 128
if self.shuffle_embed:
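# randomly permute the order of the (T x N) tokens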
indices = torch.randperm(tensor.shape[0])
tensor = tensor[indices]
if self.aggregation == 'mlp_copy_row':
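# replace every token with a copy of the first token before scoring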
first_elem = tensor[0]
tensor = first_elem.unsqueeze(0).expand(tensor.shape)
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = torch.flatten(tensor, start_dim=1, end_dim=-1)
# tensor is B x (T x N x 128)
scores = self.score(tensor).squeeze(-1)
if self.aggregation == 'mlp':
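# 'mlp': flatten all token embeddings and score them with a single MLP head (self.score)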
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = torch.flatten(tensor, start_dim=1, end_dim=-1)
# tensor is B x (T x N x 128)
scores = self.score(tensor).squeeze(-1)
elif self.aggregation == 'mean':
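# 'mean': average every element of the flattened token embeddings into one scalar per example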
# tensor is (T x N) x B x 128
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = torch.flatten(tensor, start_dim=1, end_dim=-1)
# tensor is B x (T x N x 128)
scores = torch.mean(tensor, dim=-1)
elif self.aggregation == 'mlp_mean':
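# 'mlp_mean': score each token with self.score_obj, then average the per-token scores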
# tensor is (T x N) x B x 128
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = self.score_obj(tensor).squeeze(-1)
# tensor is B x (T x N)
scores = torch.mean(tensor, dim=-1)
elif self.aggregation == 'mean_pool_over_objects':
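# 'mean_pool_over_objects': mean-pool over the object dimension per timestep, then score the flattened per-timestep embeddings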
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x E
tensor = tensor.reshape(tensor_enc.shape)
# tensor is B x T x N x E
mean_pooled_obj = torch.mean(tensor, dim=2).squeeze(2)
# mean_pooled_obj is B x T x E
flattened_pooled = torch.flatten(mean_pooled_obj, start_dim=1)
# flattened_pooled is B x (T x E)
scores = self.score_timestep_embeddings(flattened_pooled).squeeze(-1)
# scores is [B,]
elif self.aggregation == 'max_pool_over_objects':
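# 'max_pool_over_objects': same as above, but max-pool over the object dimension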
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x E
tensor = tensor.reshape(tensor_enc.shape)
# tensor is B x T x N x E
max_pooled_obj = torch.max(tensor, dim=2).values.squeeze(2)
# max_pooled_obj is B x T x E
flattened_pooled = torch.flatten(max_pooled_obj, start_dim=1)
# flattened_pooled is B x (T x E)
scores = self.score_timestep_embeddings(flattened_pooled).squeeze(-1)
# scores is [B,]
elif self.aggregation == 'max':
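# 'max': take the single largest value of the flattened token embeddings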
# tensor is (T x N) x B x 128
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = torch.flatten(tensor, start_dim=1, end_dim=-1)
# tensor is B x (T x N x 128)
scores = torch.max(tensor, dim=-1).values.squeeze(-1)
elif self.aggregation == 'mlp_max':
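# 'mlp_max': score each token with self.score_obj, then take the maximum per-token score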
# tensor is (T x N) x B x 128
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x 128
tensor = self.score_obj(tensor).squeeze(-1)
# tensor is B x (T x N)
scores = torch.max(tensor, dim=-1).values.squeeze(-1)
elif self.aggregation == 'goal_mlp':
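# 'goal_mlp': select only the goal-colored objects (green / blue / purple one-hot color features) and score their embeddings with self.score_goal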
tensor = tensor.permute(1, 0, 2)
# tensor is B x (T x N) x E
tensor = tensor.reshape(tensor_enc.shape)
goal_tensor = t.clone()
purple = goal_tensor[:, :, :, -3] == 1.
blue = goal_tensor[:, :, :, -4] == 1.
green = goal_tensor[:, :, :, -5] == 1.
goal_indices = purple | blue | green
goal_objs = tensor[goal_indices].reshape(
(tensor.shape[0], tensor.shape[1], -1, tensor.shape[-1]))
goal_objs = torch.flatten(goal_objs, start_dim=1)
scores = self.score_goal(goal_objs).squeeze(-1)
# scores is [B,]
return scores
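# Usage sketch (hypothetical; the class name `ObjectNet` and its constructor
# arguments below are illustrative assumptions, not taken from this file):
#
#   net = ObjectNet(aggregation='mlp', embed_tf=True, shuffle_embed=False)
#   rollout = torch.randn(B, T, N, F)  # object-feature rollouts, B x T x N x F
#   scores = net(rollout)              # -> [B,] one scalar score per rollout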