def update()

in ss_baselines/savi/models/belief_predictor.py [0:0]


    def update(self, observations, dones):
        """
        Overwrite the belief slots in ``observations`` with the current estimates.

        Each environment in the batch is handled independently.

        * Location belief (only when ``self.predict_location``): if the
          environment's spectrogram sums to a non-zero value, the pointgoal
          predicted by ``self.cnn_forward`` is converted to the agent frame.
          It is then either used as-is or blended with the previous estimate
          via ``self.config.weighting_factor``. The result is remembered (in
          odometry coordinates) for later steps. If the spectrogram sums to
          zero, the remembered estimate is re-expressed in the agent's current
          frame instead.
        * Category belief (only when ``self.predict_label``): the same
          blend-or-reuse scheme, applied to the first 21 outputs of
          ``self.classifier``.

        Args:
            observations: mapping of batched tensors. It is read for the
                spectrogram and ``'pose'`` entries, and the ``LocationBelief``
                / ``CategoryBelief`` entries are written in place with
                ``copy_``.
            dones: per-environment done flags, or ``None``. A truthy flag
                clears that environment's remembered estimate before it is
                used.
        """
        spec_key = SpectrogramSensor.cls_uuid
        batch_size = observations[spec_key].size(0)
        if self.predict_label or self.predict_location:
            # classifier input layout: move the last axis to position 1
            spectrograms = observations[spec_key].permute(0, 3, 1, 2)

        if self.predict_location:
            # predicted pointgoal: X is rightward, -Y is forward, heading increases X to Y, agent faces -Y
            with torch.no_grad():
                predicted_goals = self.cnn_forward(observations).cpu().numpy()

            for idx in range(batch_size):
                agent_pose = observations['pose'][idx].cpu().numpy()
                raw_goal = predicted_goals[idx]
                if dones is not None and dones[idx]:
                    # episode ended: forget the stored estimate for this env
                    self.last_pointgoal[idx] = None

                heard_sound = observations[spec_key][idx].sum().item() != 0
                previous = self.last_pointgoal[idx]

                if heard_sound:
                    # pointgoal_with_gps_compass: X is forward, Y is rightward,
                    # pose: same XY but heading is positive from X to -Y defined based on the initial pose
                    current = np.array([-raw_goal[1], raw_goal[0]])
                    if previous is None or self.config.current_pred_only:
                        estimate = current
                    else:
                        weight = self.config.weighting_factor
                        estimate = (1 - weight) * current + weight * odom_to_base(previous, agent_pose)
                    # remember the estimate in odometry coordinates so it can be
                    # re-projected after the agent moves
                    self.last_pointgoal[idx] = base_to_odom(estimate, agent_pose)
                elif previous is None:
                    # placeholder when no estimate exists yet; the meaning of
                    # (10, 10) is not documented here
                    estimate = np.array([10, 10])
                else:
                    estimate = odom_to_base(previous, agent_pose)

                observations[LocationBelief.cls_uuid][idx].copy_(torch.from_numpy(estimate))

        if self.predict_label:
            with torch.no_grad():
                class_scores = self.classifier(spectrograms)[:, :21].cpu().numpy()

            for idx in range(batch_size):
                scores = class_scores[idx]
                if dones is not None and dones[idx]:
                    # episode ended: forget the stored estimate for this env
                    self.last_label[idx] = None

                heard_sound = observations[spec_key][idx].sum().item() != 0
                previous = self.last_label[idx]

                if heard_sound:
                    if previous is None or self.config.current_pred_only:
                        estimate = scores
                    else:
                        weight = self.config.weighting_factor
                        estimate = (1 - weight) * scores + weight * previous
                    self.last_label[idx] = estimate
                elif previous is None:
                    logging.debug("Empty RIR after done")
                    # no information yet: uniform vector over the 21 categories
                    estimate = np.ones(21) / 21
                else:
                    estimate = previous

                observations[CategoryBelief.cls_uuid][idx].copy_(torch.from_numpy(estimate))