in ss_baselines/savi/models/belief_predictor.py [0:0]
def update(self, observations, dones):
    """Refresh the location and category beliefs in ``observations`` in place.

    For every environment ``i`` in the batch:

    * if ``dones`` marks env ``i`` as done, the remembered estimates
      (``self.last_pointgoal[i]`` / ``self.last_label[i]``) are cleared first;
    * if the env's spectrogram observation is non-zero (its sum is not 0), the
      fresh prediction is blended with the previous estimate using
      ``self.config.weighting_factor`` (or used on its own when there is no
      previous estimate or ``self.config.current_pred_only`` is set) and the
      result is remembered;
    * otherwise the previous estimate is reused -- for the location it is
      transformed to the agent's current coordinate frame with
      ``odom_to_base`` -- or a fixed fallback is written when no previous
      estimate exists.

    Args:
        observations: dict of batched observations. Reads the
            ``SpectrogramSensor.cls_uuid`` and ``'pose'`` entries; the rows of
            ``LocationBelief.cls_uuid`` and ``CategoryBelief.cls_uuid`` are
            overwritten in place via ``copy_``.
        dones: per-environment done flags, or ``None`` to skip resets.
    """
    spectrogram_obs = observations[SpectrogramSensor.cls_uuid]
    batch_size = spectrogram_obs.size(0)
    if not (self.predict_label or self.predict_location):
        return

    # One batched reduction and a single host transfer instead of a per-env
    # `.sum().item()` in each branch (each of which forces a device sync).
    # NOTE(review): "sum != 0" is a proxy for "not silent"; a non-zero
    # spectrogram whose values cancel to exactly 0 would count as silent.
    # That cannot happen if spectrogram values are non-negative -- confirm.
    has_audio = (spectrogram_obs.flatten(start_dim=1).sum(dim=1) != 0).tolist()

    if self.predict_location:
        # predicted pointgoal: X is rightward, -Y is forward, heading increases X to Y, agent faces -Y
        with torch.no_grad():
            pointgoals = self.cnn_forward(observations).cpu().numpy()
        location_beliefs = observations[LocationBelief.cls_uuid]
        for i in range(batch_size):
            pose = observations['pose'][i].cpu().numpy()
            pointgoal = pointgoals[i]
            if dones is not None and dones[i]:
                self.last_pointgoal[i] = None
            last_pointgoal = self.last_pointgoal[i]
            if has_audio[i]:
                # pointgoal_with_gps_compass: X is forward, Y is rightward,
                # pose: same XY but heading is positive from X to -Y defined based on the initial pose
                pointgoal_base = np.array([-pointgoal[1], pointgoal[0]])
                if last_pointgoal is None or self.config.current_pred_only:
                    pointgoal_avg = pointgoal_base
                else:
                    w = self.config.weighting_factor
                    pointgoal_avg = (1 - w) * pointgoal_base + w * odom_to_base(last_pointgoal, pose)
                # Remember the estimate via base_to_odom so that, on later silent
                # steps, odom_to_base can bring it back into the agent's frame.
                self.last_pointgoal[i] = base_to_odom(pointgoal_avg, pose)
            elif last_pointgoal is None:
                # No estimate has been made yet: write the fixed fallback.
                pointgoal_avg = np.array([10, 10])
            else:
                pointgoal_avg = odom_to_base(last_pointgoal, pose)
            location_beliefs[i].copy_(torch.from_numpy(pointgoal_avg))

    if self.predict_label:
        # Only the first `num_categories` classifier outputs are used; the same
        # count sizes the uniform fallback so the two cannot drift apart.
        num_categories = 21
        spectrograms = spectrogram_obs.permute(0, 3, 1, 2)
        with torch.no_grad():
            labels = self.classifier(spectrograms)[:, :num_categories].cpu().numpy()
        category_beliefs = observations[CategoryBelief.cls_uuid]
        for i in range(batch_size):
            label = labels[i]
            if dones is not None and dones[i]:
                self.last_label[i] = None
            last_label = self.last_label[i]
            if has_audio[i]:
                if last_label is None or self.config.current_pred_only:
                    label_avg = label
                else:
                    w = self.config.weighting_factor
                    label_avg = (1 - w) * label + w * last_label
                self.last_label[i] = label_avg
            elif last_label is None:
                logging.debug("Empty RIR after done")
                # Uniform fallback over the categories when nothing was heard yet.
                label_avg = np.ones(num_categories) / num_categories
            else:
                label_avg = last_label
            category_beliefs[i].copy_(torch.from_numpy(label_avg))