def forward()

in agents/obj_nets.py [0:0]


    def forward(self, t):
        """Embed a batch of object-feature rollouts and reduce them to scores.

        Args:
            t: Input tensor laid out as B x T x N x F (batch, timesteps,
                objects, features). An object row whose features sum to 0 is
                treated as padding and masked in the transformer encoder
                (when ``self.embed_tf`` is set). ``t`` is not modified.

        Returns:
            Tensor of shape [B], one score per batch element (for the
            MLP-based modes this assumes the score heads output B x 1).

        Raises:
            ValueError: if ``self.aggregation`` is not a supported mode.
        """
        tensor = t.clone()
        if self.aggregation == 'mlp':
            # If the caller stripped padding (e.g. agent.strip_padding) and
            # there are fewer than MAX_NUM_OBJECTS objects, re-pad the object
            # axis with all-zero rows; presumably the flattened 'mlp' score
            # head expects a fixed input size.
            num_padding = phyre_simulator.MAX_NUM_OBJECTS - tensor.shape[2]
            if num_padding > 0:
                pad_zeros = tensor.new_zeros(
                    (tensor.shape[0], tensor.shape[1], num_padding,
                     tensor.shape[3]))
                tensor = torch.cat((tensor, pad_zeros), dim=2)
        # Object rows whose features sum to zero are treated as padding.
        is_pad = torch.sum(tensor, dim=-1) == 0
        # B x T x N x F -> B x T x N x E
        tensor_enc = self.encoder(tensor)
        tensor = self.pos_encoder(tensor_enc) * math.sqrt(self.n_inp)
        # B x T x N x E -> B x (T x N) x E
        tensor = torch.flatten(tensor, start_dim=1, end_dim=2)
        mask = torch.flatten(is_pad, start_dim=1, end_dim=2)  # B x (T x N)
        # Sequence-first layout: (T x N) x B x E
        tensor = tensor.permute(1, 0, 2)
        if self.embed_tf:
            tensor = self.transformer_encoder(tensor,
                                              src_key_padding_mask=mask)
        if self.shuffle_embed:
            # Randomly permute the (T x N) sequence positions.
            indices = torch.randperm(tensor.shape[0])
            tensor = tensor[indices]

        # From here on `tensor` is (T x N) x B x E.
        if self.aggregation in ('mlp', 'mlp_copy_row'):
            if self.aggregation == 'mlp_copy_row':
                # Broadcast the first sequence position's embedding to every
                # position. NOTE(review): unlike 'mlp', this mode is not
                # re-padded above — confirm self.score's input size matches.
                tensor = tensor[0].unsqueeze(0).expand(tensor.shape)
            # -> B x (T x N x E)
            tensor = torch.flatten(tensor.permute(1, 0, 2), start_dim=1)
            scores = self.score(tensor).squeeze(-1)
        elif self.aggregation == 'mean':
            # -> B x (T x N x E), then average every element.
            tensor = torch.flatten(tensor.permute(1, 0, 2), start_dim=1)
            scores = torch.mean(tensor, dim=-1)
        elif self.aggregation == 'max':
            # -> B x (T x N x E), then take the max element. `.values` is
            # already [B]; no squeeze, so B == 1 still yields shape [1].
            tensor = torch.flatten(tensor.permute(1, 0, 2), start_dim=1)
            scores = torch.max(tensor, dim=-1).values
        elif self.aggregation in ('mlp_mean', 'mlp_max'):
            # Score each position separately: B x (T x N) x E -> B x (T x N)
            obj_scores = self.score_obj(tensor.permute(1, 0, 2)).squeeze(-1)
            if self.aggregation == 'mlp_mean':
                scores = torch.mean(obj_scores, dim=-1)
            else:
                scores = torch.max(obj_scores, dim=-1).values
        elif self.aggregation in ('mean_pool_over_objects',
                                  'max_pool_over_objects'):
            # (T x N) x B x E -> B x T x N x E
            tensor = tensor.permute(1, 0, 2).reshape(tensor_enc.shape)
            # Pool over the object axis: -> B x T x E
            if self.aggregation == 'mean_pool_over_objects':
                pooled = torch.mean(tensor, dim=2)
            else:
                pooled = torch.max(tensor, dim=2).values
            # -> B x (T x E)
            flattened_pooled = torch.flatten(pooled, start_dim=1)
            scores = self.score_timestep_embeddings(flattened_pooled).squeeze(
                -1)
        elif self.aggregation == 'goal_mlp':
            # (T x N) x B x E -> B x T x N x E
            tensor = tensor.permute(1, 0, 2).reshape(tensor_enc.shape)
            # Goal objects are flagged in the raw input features; presumably
            # channels -3 / -4 / -5 are purple / blue / green indicators.
            is_goal = ((t[..., -3] == 1.) | (t[..., -4] == 1.) |
                       (t[..., -5] == 1.))
            # NOTE(review): assumes every (batch, timestep) has the same
            # number of goal objects, otherwise this reshape fails.
            goal_objs = tensor[is_goal].reshape(
                (tensor.shape[0], tensor.shape[1], -1, tensor.shape[-1]))
            goal_objs = torch.flatten(goal_objs, start_dim=1)
            scores = self.score_goal(goal_objs).squeeze(-1)
        else:
            raise ValueError(
                'Unknown aggregation: {!r}'.format(self.aggregation))
        # scores is [B]
        return scores