ss_baselines/savi/pretraining/audiogoal_trainer.py
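# Supervised pretraining loop for the audiogoal predictor: builds one
# AudioGoalDataset per split from the scene metadata, runs epoch-wise
# train/val phases, and keeps the weights with the best val accuracy.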
def run(self, splits, writer=None):
meta_dir = self.config.TASK_CONFIG.SIMULATOR.AUDIO.METADATA_DIR
datasets = dict()
dataloaders = dict()
dataset_sizes = dict()
for split in splits:
scenes = SCENE_SPLITS[split]
# build the scene graphs needed by the collected subgoal dataset
scene_graphs = dict()
for scene in scenes:
_, graph = load_metadata(os.path.join(meta_dir, 'mp3d', scene))
scene_graphs[scene] = graph
datasets[split] = AudioGoalDataset(
scene_graphs=scene_graphs,
scenes=scenes,
split=split,
use_polar_coordinates=False,
use_cache=True
)
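# use_cache presumably lets the dataset reuse precomputed audio features
# across epochs instead of regenerating them (an assumption; see
# AudioGoalDataset for the actual caching behavior)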
dataloaders[split] = DataLoader(dataset=datasets[split],
batch_size=self.batch_size,
shuffle=False,
pin_memory=True,
num_workers=self.num_worker,
sampler=None,
)
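# each batch is assumed to yield (inputs, gts): a list of input tensors
# (only the spectrogram here) and a float target tensor whose columns are
# [label, x, y]; see the unpacking in the epoch loop below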
dataset_sizes[split] = len(datasets[split])
logging.info('{} has {} samples'.format(split.upper(), dataset_sizes[split]))
regressor_criterion = nn.MSELoss().to(device=self.device)
classifier_criterion = nn.CrossEntropyLoss().to(device=self.device)
model = self.audiogoal_predictor
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
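# Adam with PyTorch defaults (lr=1e-3), restricted to trainable parameters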
# training params
since = time.time()
best_acc = 0
best_model_wts = None
num_epoch = self.num_epoch if 'train' in splits else 1
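# a single pass suffices when no training split is requested (eval-only)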
for epoch in range(num_epoch):
logging.info('-' * 10)
logging.info('Epoch {}/{}'.format(epoch + 1, num_epoch))
# Each epoch has a training and validation phase
for split in splits:
if split == 'train':
self.audiogoal_predictor.train() # Set model to training mode
else:
self.audiogoal_predictor.eval() # Set model to evaluate mode
running_total_loss = 0.0
running_regressor_loss = 0.0
running_classifier_loss = 0.0
running_regressor_corrects = 0
running_classifier_corrects = 0
# Iterating over data once is one epoch
for i, data in enumerate(tqdm(dataloaders[split])):
# get the inputs
inputs, gts = data
# move inputs and targets onto the device as float tensors
inputs = [x.to(device=self.device, dtype=torch.float) for x in inputs]
gts = gts.to(device=self.device, dtype=torch.float)
# zero the parameter gradients
optimizer.zero_grad()
# forward
predicts = model({input_type: x for input_type, x in zip(['spectrogram'], inputs)})
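# the model consumes a dict keyed by input type; predicts packs the class
# logits first and the (x, y) location in the last two columns when both
# heads are enabled, mirroring gts = [label, x, y]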
if self.predict_label and self.predict_location:
classifier_loss = classifier_criterion(predicts[:, :-2], gts[:, 0].long())
regressor_loss = regressor_criterion(predicts[:, -2:], gts[:, -2:])
elif self.predict_label:
classifier_loss = classifier_criterion(predicts, gts[:, 0].long())
regressor_loss = torch.tensor([0], device=self.device)
elif self.predict_location:
regressor_loss = regressor_criterion(predicts, gts[:, -2:])
classifier_loss = torch.tensor([0], device=self.device)
else:
raise ValueError('Must predict one item.')
loss = classifier_loss + regressor_loss
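# joint objective; the disabled head contributes a constant zero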
# backward + optimize only if in training phase
if split == 'train':
loss.backward()
optimizer.step()
running_total_loss += loss.item() * gts.size(0)
running_classifier_loss += classifier_loss.item() * gts.size(0)
running_regressor_loss += regressor_loss.item() * gts.size(0)
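# weight batch losses by batch size so the epoch averages are per-sample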
# round predictions and targets for hard (exact-match) accuracy
preds_np = np.round(predicts.cpu().detach().numpy())
gts_np = np.round(gts.cpu().numpy())
# hard accuracy
if self.predict_label and self.predict_location:
# a location counts as correct only if both rounded coordinates match
running_regressor_corrects += np.sum(np.logical_and(
preds_np[:, -2] == gts_np[:, -2], preds_np[:, -1] == gts_np[:, -1]))
# argmax over the raw logits; taking abs() first would mis-rank classes
running_classifier_corrects += torch.sum(
torch.argmax(predicts[:, :-2], dim=1) == gts[:, 0]).item()
elif self.predict_label:
running_classifier_corrects += torch.sum(
torch.argmax(predicts, dim=1) == gts[:, 0]).item()
elif self.predict_location:
running_regressor_corrects += np.sum(np.logical_and(
preds_np[:, 0] == gts_np[:, -2], preds_np[:, 1] == gts_np[:, -1]))
epoch_total_loss = running_total_loss / dataset_sizes[split]
epoch_regressor_loss = running_regressor_loss / dataset_sizes[split]
epoch_classifier_loss = running_classifier_loss / dataset_sizes[split]
epoch_regressor_acc = running_regressor_corrects / dataset_sizes[split]
epoch_classifier_acc = running_classifier_corrects / dataset_sizes[split]
if writer is not None:
writer.add_scalar(f'Loss/{split}_total', epoch_total_loss, epoch)
writer.add_scalar(f'Loss/{split}_classifier', epoch_classifier_loss, epoch)
writer.add_scalar(f'Loss/{split}_regressor', epoch_regressor_loss, epoch)
writer.add_scalar(f'Accuracy/{split}_classifier', epoch_classifier_acc, epoch)
writer.add_scalar(f'Accuracy/{split}_regressor', epoch_regressor_acc, epoch)
logging.info(f'{split.upper()} Total loss: {epoch_total_loss:.4f}, '
f'label loss: {epoch_classifier_loss:.4f}, xy loss: {epoch_regressor_loss:.4f}, '
f'label acc: {epoch_classifier_acc:.4f}, xy acc: {epoch_regressor_acc:.4f}')
# deep copy the model
if self.predict_label and self.predict_location:
target_acc = epoch_regressor_acc + epoch_classifier_acc
elif self.predict_location:
target_acc = epoch_regressor_acc
else:
target_acc = epoch_classifier_acc
if split == 'val' and target_acc > best_acc:
best_acc = target_acc
best_model_wts = copy.deepcopy(model.state_dict())
self.save_checkpoint(f"ckpt.{epoch}.pth")
self.save_checkpoint(f"best_val.pth", checkpoint={"audiogoal_predictor": best_model_wts})
time_elapsed = time.time() - since
logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
logging.info('Best val acc: {:.4f}'.format(best_acc))
if best_model_wts is not None:
model.load_state_dict(best_model_wts)