def run()

in ss_baselines/savi/pretraining/audiogoal_trainer.py

Pretrains the AudioGoal predictor in a supervised fashion: for each requested split it builds an AudioGoalDataset over that split's scenes, then runs an epoch loop that regresses the goal's (x, y) location and/or classifies the goal label from audio spectrograms, logging losses and accuracies and checkpointing the best validation model.

    def run(self, splits, writer=None):
        meta_dir = self.config.TASK_CONFIG.SIMULATOR.AUDIO.METADATA_DIR

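        # build one dataset and dataloader per requested split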
        datasets = dict()
        dataloaders = dict()
        dataset_sizes = dict()
        for split in splits:
            scenes = SCENE_SPLITS[split]
            # collect the navigation graph for every scene in this split
            scene_graphs = dict()
            for scene in scenes:
                points, graph = load_metadata(os.path.join(meta_dir, 'mp3d', scene))
                scene_graphs[scene] = graph
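            # AudioGoalDataset pairs audio spectrograms with goal annotations for
            # these scenes; use_cache=True presumably reuses precomputed samples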
            datasets[split] = AudioGoalDataset(
                scene_graphs=scene_graphs,
                scenes=scenes,
                split=split,
                use_polar_coordinates=False,
                use_cache=True
            )
            dataloaders[split] = DataLoader(dataset=datasets[split],
                                            batch_size=self.batch_size,
                                            shuffle=(split == 'train'),  # shuffle training data each epoch
                                            pin_memory=True,
                                            num_workers=self.num_worker,
                                            sampler=None,
                                            )

            dataset_sizes[split] = len(datasets[split])
            logging.info('{} has {} samples'.format(split.upper(), dataset_sizes[split]))

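        # MSE for the regressed (x, y) goal location, cross-entropy for the goal label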
        regressor_criterion = nn.MSELoss().to(device=self.device)
        classifier_criterion = nn.CrossEntropyLoss().to(device=self.device)
        model = self.audiogoal_predictor
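        # Adam with its default hyperparameters (lr=1e-3), restricted to trainable parameters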
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))

        # training bookkeeping; evaluation-only runs make a single pass
        since = time.time()
        best_acc = 0
        best_model_wts = None
        num_epoch = self.num_epoch if 'train' in splits else 1
        for epoch in range(num_epoch):
            logging.info('-' * 10)
            logging.info('Epoch {}/{}'.format(epoch + 1, num_epoch))

            # Each epoch has a training and validation phase
            for split in splits:
                if split == 'train':
                    self.audiogoal_predictor.train()  # Set model to training mode
                else:
                    self.audiogoal_predictor.eval()  # Set model to evaluate mode
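                # train()/eval() only toggle layers such as dropout and batch norm;
                # eval batches below still build a graph, but no optimizer step is taken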

                running_total_loss = 0.0
                running_regressor_loss = 0.0
                running_classifier_loss = 0.0
                running_regressor_corrects = 0
                running_classifier_corrects = 0

                # Iterating over data once is one epoch
                for i, data in enumerate(tqdm(dataloaders[split])):
                    # get the inputs
                    inputs, gts = data

                    # move inputs and ground truths to the training device as floats
                    inputs = [x.to(device=self.device, dtype=torch.float) for x in inputs]
                    gts = gts.to(device=self.device, dtype=torch.float)

                    # zero the parameter gradients
                    optimizer.zero_grad()

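                    # the predictor consumes a dict of named input tensors;
                    # only the 'spectrogram' modality is supplied here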
                    # forward
                    predicts = model({input_type: x for input_type, x in zip(['spectrogram'], inputs)})

                    # when both targets are predicted, the last two output columns
                    # hold the regressed (x, y) and the remaining columns the class logits
                    if self.predict_label and self.predict_location:
                        classifier_loss = classifier_criterion(predicts[:, :-2], gts[:, 0].long())
                        regressor_loss = regressor_criterion(predicts[:, -2:], gts[:, -2:])
                    elif self.predict_label:
                        classifier_loss = classifier_criterion(predicts, gts[:, 0].long())
                        regressor_loss = torch.tensor(0., device=self.device)
                    elif self.predict_location:
                        regressor_loss = regressor_criterion(predicts, gts[:, -2:])
                        classifier_loss = torch.tensor(0., device=self.device)
                    else:
                        raise ValueError('Must predict at least one of label or location.')
                    loss = classifier_loss + regressor_loss

                    # backward + optimize only if in training phase
                    if split == 'train':
                        loss.backward()
                        optimizer.step()

                    running_total_loss += loss.item() * gts.size(0)
                    running_classifier_loss += classifier_loss.item() * gts.size(0)
                    running_regressor_loss += regressor_loss.item() * gts.size(0)

                    # rounded copies for "hard" accuracy: a location prediction counts
                    # as correct only if both coordinates round to the ground-truth cell
                    pred_rounded = np.round(predicts.cpu().detach().numpy())
                    gt_rounded = np.round(gts.cpu().numpy())

                    if self.predict_label and self.predict_location:
                        running_regressor_corrects += np.sum(np.bitwise_and(
                            pred_rounded[:, -2] == gt_rounded[:, -2],
                            pred_rounded[:, -1] == gt_rounded[:, -1]))
                        # argmax over the raw logits gives the predicted class
                        running_classifier_corrects += torch.sum(
                            torch.argmax(predicts[:, :-2], dim=1) == gts[:, 0]).item()
                    elif self.predict_label:
                        running_classifier_corrects += torch.sum(
                            torch.argmax(predicts, dim=1) == gts[:, 0]).item()
                    elif self.predict_location:
                        running_regressor_corrects += np.sum(np.bitwise_and(
                            pred_rounded[:, 0] == gt_rounded[:, -2],
                            pred_rounded[:, 1] == gt_rounded[:, -1]))

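                # normalize running sums by the dataset size to get per-epoch averages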
                epoch_total_loss = running_total_loss / dataset_sizes[split]
                epoch_regressor_loss = running_regressor_loss / dataset_sizes[split]
                epoch_classifier_loss = running_classifier_loss / dataset_sizes[split]
                epoch_regressor_acc = running_regressor_corrects / dataset_sizes[split]
                epoch_classifier_acc = running_classifier_corrects / dataset_sizes[split]
                if writer is not None:
                    writer.add_scalar(f'Loss/{split}_total', epoch_total_loss, epoch)
                    writer.add_scalar(f'Loss/{split}_classifier', epoch_classifier_loss, epoch)
                    writer.add_scalar(f'Loss/{split}_regressor', epoch_regressor_loss, epoch)
                    writer.add_scalar(f'Accuracy/{split}_classifier', epoch_classifier_acc, epoch)
                    writer.add_scalar(f'Accuracy/{split}_regressor', epoch_regressor_acc, epoch)
                logging.info(f'{split.upper()} Total loss: {epoch_total_loss:.4f}, '
                             f'label loss: {epoch_classifier_loss:.4f}, xy loss: {epoch_regressor_loss:.4f},'
                             f' label acc: {epoch_classifier_acc:.4f}, xy acc: {epoch_regressor_acc:.4f}')

                # model selection: pick the accuracy that matches the prediction targets
                if self.predict_label and self.predict_location:
                    target_acc = epoch_regressor_acc + epoch_classifier_acc
                elif self.predict_location:
                    target_acc = epoch_regressor_acc
                else:
                    target_acc = epoch_classifier_acc

                if split == 'val' and target_acc > best_acc:
                    best_acc = target_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    self.save_checkpoint(f"ckpt.{epoch}.pth")

        self.save_checkpoint(f"best_val.pth", checkpoint={"audiogoal_predictor": best_model_wts})

        time_elapsed = time.time() - since
        logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        logging.info('Best val acc: {:.4f}'.format(best_acc))

        if best_model_wts is not None:
            model.load_state_dict(best_model_wts)
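
The joint objective above combines a cross-entropy head and an MSE head over a single output tensor, with the last two columns reserved for the regressed coordinates. A self-contained sketch of that pattern on toy tensors (the sizes here are illustrative assumptions, not taken from this file):

    import torch
    import torch.nn as nn

    # toy batch: 4 samples, 10 class logits plus 2 regressed coordinates per row
    predicts = torch.randn(4, 12)
    # ground-truth layout mirrors run(): column 0 is the label, last two are (x, y)
    gts = torch.cat([torch.randint(0, 10, (4, 1)).float(), torch.randn(4, 2)], dim=1)

    classifier_loss = nn.CrossEntropyLoss()(predicts[:, :-2], gts[:, 0].long())
    regressor_loss = nn.MSELoss()(predicts[:, -2:], gts[:, -2:])
    loss = classifier_loss + regressor_loss  # same joint loss as in run()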