in agents/obj_nets.py [0:0]
def forward(self, tensor):
tensor = tensor.clone()
receiver_matrix = []
sender_matrix = []
relation_matrix = []
external_effects_matrix = []
for i in range(tensor.shape[0]):
black = tensor[i, 0, :, -1] == 1.
purple = tensor[i, 0, :, -3] == 1.
static_indicies = (black + purple).type(torch.IntTensor)
row_sum = torch.sum(tensor[i, 0], dim=-1)
pad = 2 * (row_sum == 0).type(torch.IntTensor)
type_values = pad + static_indicies
type_tuple = tuple(type_values.tolist())
if type_tuple not in self.relation_matricies:
rec, send, rel = self._initialize_relations_matricies(
tensor[i, 0])
self.relation_matricies[type_tuple] = (rec, send, rel)
else:
(rec, send, rel) = self.relation_matricies[type_tuple]
rec = rec.to(tensor.device)
send = send.to(tensor.device)
rel = rel.to(tensor.device)
if type_tuple not in self.external_effects_matricies:
ext = self._intialize_external_effects_matrix(tensor[i, 0])
self.external_effects_matricies[type_tuple] = ext
else:
ext = self.external_effects_matricies[type_tuple]
ext = ext.to(tensor.device)
external_effects_matrix.append(ext)
receiver_matrix.append(rec)
sender_matrix.append(send)
relation_matrix.append(rel)
external_effects_matrix = torch.stack(external_effects_matrix, dim=0)
receiver_matrix = torch.stack(receiver_matrix, dim=0)
sender_matrix = torch.stack(sender_matrix, dim=0)
relation_matrix = torch.stack(relation_matrix, dim=0)
# input tensor is B x n-frames (2) x N-objects x VECTOR_LENGTH
assert tensor.shape[1] == 2
velocities = tensor.select(1, -1).narrow(2, 0, 3) - tensor.select(
1, 0).narrow(2, 0, 3)
# object_features is B x N-objects x (3 (velocities) + VECTOR_LENGTH)
# Ds = (3 (velocities) + VECTOR_LENGTH)
object_features = torch.cat((velocities, tensor.select(1, -1)), dim=-1)
# permute to B x state x N objects
object_features = object_features.permute(0, 2, 1)
# Dr is 1 (all 0's)
# Nr = number of relations = # dynamic objects * (# objects - 1)
# interactions is a B x (2Ds + Dr) x Nr
interactions = self._marshall_interactions(object_features,
receiver_matrix,
sender_matrix,
relation_matrix)
#effects should be B x self.interactions_size # Nr
effects = self._compute_effects(interactions)
#aggregated_effects should be B x (DS + DX + self.interactions_size) x #Number objects
aggregated_effects = self._aggregate_effects(effects, receiver_matrix,
external_effects_matrix,
object_features)
#aggregated_effects matrix is B x (DS ({DS}) + DX ({DX}) + self.#interactions_size x N-objects
# states is a B x 3 x N objects matrix of velocities
states = self._compute_states(aggregated_effects)
# transpose back to B x N x 3
states = states.permute(0, 2, 1)
# add to previous state to get predictions
next_state = tensor.select(1, -1)
if self.dont_predict_ball_theta_v:
is_ball = next_state[:, :, 4] == 1.
if len(states[is_ball]) > 0:
# This doesn't actually state values, remove or update
states[is_ball][-1] = 0.
if self.only_pred_dynamic:
black = next_state[:, :, -1] == 1.
purple = next_state[:, :, -3] == 1.
static_indicies = black + purple
states[static_indicies] = 0.0
row_sum = torch.sum(next_state, dim=-1)
states[row_sum == 0] = 0.0
prediction = next_state.clone()
prediction[:, :, :3] += states
if self.clip_output is not None:
prediction[:, :, :3] = self.clip_output(
prediction[:, :, :3].clone()).clone()
if self.only_pred_dynamic:
prediction[:, :, :3][static_indicies] = next_state[:, :, :3][
static_indicies]
return prediction