def forward()

in agents/obj_nets.py [0:0]


    def forward(self, tensor):
        """Predict the next frame of object features from two observed frames.

        For every sample, each object gets a type code built from its first
        frame: 1 if its feature ``-1`` or ``-3`` equals 1 (black / purple,
        treated as static), 2 if its whole feature row sums to 0 (padding),
        else 0.  The tuple of codes keys ``self.relation_matricies`` and
        ``self.external_effects_matricies``; missing entries are built with
        ``_initialize_relations_matricies`` /
        ``_intialize_external_effects_matrix`` and cached.  The stacked
        matrices are then fed through ``_marshall_interactions`` ->
        ``_compute_effects`` -> ``_aggregate_effects`` -> ``_compute_states``,
        and the resulting per-object deltas are added to the first three
        features of the last frame.

        Args:
            tensor: B x 2 x N x VECTOR_LENGTH tensor (two frames per sample).

        Returns:
            B x N x VECTOR_LENGTH tensor: a copy of the last frame whose first
            three features are advanced by the predicted deltas (optionally
            passed through ``self.clip_output``).

        Raises:
            AssertionError: if ``tensor`` does not hold exactly 2 frames.
        """
        tensor = tensor.clone()
        # Input is B x n-frames (2) x N-objects x VECTOR_LENGTH; validate the
        # frame count before doing any per-sample work.
        assert tensor.shape[1] == 2
        device = tensor.device

        receiver_matrix = []
        sender_matrix = []
        relation_matrix = []
        external_effects_matrix = []
        # Matrices are cached per object-type pattern, i.e. the code assumes
        # they depend only on that pattern.
        for frame in tensor[:, 0]:
            black = frame[:, -1] == 1.
            purple = frame[:, -3] == 1.
            padding = frame.sum(dim=-1) == 0
            # 0 = other object, 1 = black/purple, 2 = all-zero padding row.
            # `.int()` stays on the input's device (IntTensor forced CPU).
            type_values = (black | purple).int() + 2 * padding.int()
            type_key = tuple(type_values.tolist())

            if type_key not in self.relation_matricies:
                rec, send, rel = self._initialize_relations_matricies(frame)
                self.relation_matricies[type_key] = (rec, send, rel)
            rec, send, rel = (
                m.to(device) for m in self.relation_matricies[type_key])

            if type_key not in self.external_effects_matricies:
                self.external_effects_matricies[type_key] = (
                    self._intialize_external_effects_matrix(frame))
            ext = self.external_effects_matricies[type_key].to(device)

            receiver_matrix.append(rec)
            sender_matrix.append(send)
            relation_matrix.append(rel)
            external_effects_matrix.append(ext)
        receiver_matrix = torch.stack(receiver_matrix, dim=0)
        sender_matrix = torch.stack(sender_matrix, dim=0)
        relation_matrix = torch.stack(relation_matrix, dim=0)
        external_effects_matrix = torch.stack(external_effects_matrix, dim=0)

        last_frame = tensor[:, -1]
        # Per-object change of the first three features between the frames.
        velocities = last_frame[:, :, :3] - tensor[:, 0, :, :3]
        # object_features: B x N x Ds with Ds = 3 + VECTOR_LENGTH, then
        # permuted to B x Ds x N.
        object_features = torch.cat((velocities, last_frame), dim=-1)
        object_features = object_features.permute(0, 2, 1)

        # Per the original notes: Dr is 1 (all 0's),
        # Nr = # dynamic objects * (# objects - 1), and interactions is
        # B x (2Ds + Dr) x Nr.
        interactions = self._marshall_interactions(
            object_features, receiver_matrix, sender_matrix, relation_matrix)
        effects = self._compute_effects(interactions)
        # Per the original notes: B x (DS + DX + interactions_size) x N.
        aggregated_effects = self._aggregate_effects(
            effects, receiver_matrix, external_effects_matrix, object_features)
        # Expected B x 3 x N; transpose back to B x N x 3.
        states = self._compute_states(aggregated_effects).permute(0, 2, 1)

        if self.dont_predict_ball_theta_v:
            # Zero the last predicted component for balls (feature 4 == 1);
            # presumably that is the angular (theta) velocity -- TODO confirm.
            # The previous `states[is_ball][-1] = 0.` wrote into the copy made
            # by boolean indexing and had no effect.  masked_fill is
            # out-of-place, so it cannot corrupt tensors saved for backward.
            is_ball = last_frame[:, :, 4] == 1.
            ball_theta = torch.zeros_like(states, dtype=torch.bool)
            ball_theta[:, :, -1] = is_ball
            states = states.masked_fill(ball_theta, 0.)
        if self.only_pred_dynamic:
            static_mask = ((last_frame[:, :, -1] == 1.)
                           | (last_frame[:, :, -3] == 1.))
            states[static_mask] = 0.0
            states[last_frame.sum(dim=-1) == 0] = 0.0

        # Add the predicted deltas to the previous state.
        prediction = last_frame.clone()
        prediction[:, :, :3] += states
        if self.clip_output is not None:
            prediction[:, :, :3] = self.clip_output(
                prediction[:, :, :3].clone()).clone()
            if self.only_pred_dynamic:
                # Clipping must not move static objects: restore their first
                # three features from the last observed frame.
                prediction[:, :, :3][static_mask] = (
                    last_frame[:, :, :3][static_mask])
        return prediction