in aiops/Pathformer_ICLR2024/exp/exp_main.py [0:0]
def test(self, setting, test=0):
    """Evaluate ``self.model`` on the test split, plot samples, and log metrics.

    Runs every batch from ``self._get_data(flag='test')`` through the model
    under ``torch.no_grad()`` (inside ``torch.cuda.amp.autocast`` when
    ``self.args.use_amp`` is set). It keeps the last ``self.args.pred_len``
    steps of predictions and targets. Every 20th batch, it writes a plot of
    the first sample's last input channel followed by target/prediction to
    ``./test_results/<setting>/<batch index>.pdf`` via ``visual``. Finally it
    prints MSE / MAE / RSE from ``metric`` and appends them to ``result.txt``.

    Args:
        setting: Experiment identifier. Names the checkpoint sub-directory and
            the plot folder, and is written as a header line in ``result.txt``.
        test: When truthy, ``checkpoint.pth`` for ``setting`` is loaded into
            ``self.model`` before evaluating.

    Returns:
        None.

    Raises:
        ValueError: If the test loader yields no batches.
        SystemExit: When ``self.args.test_flop`` is set, after the
            FLOP/parameter report has been produced.
    """
    import contextlib
    import sys

    test_data, test_loader = self._get_data(flag='test')
    if test:
        print('loading model')
        # Honour a custom checkpoint directory when the args carry one. The
        # fallback reproduces the previously hard-coded location exactly.
        # NOTE(review): presumably training saves under self.args.checkpoints
        # — confirm against the train() method.
        checkpoint_dir = getattr(self.args, 'checkpoints', './checkpoints/')
        checkpoint_path = os.path.join(checkpoint_dir, setting, 'checkpoint.pth')
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host (and vice versa).
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    folder_path = './test_results/' + setting + '/'
    os.makedirs(folder_path, exist_ok=True)

    # One code path for both modes: autocast when AMP is requested,
    # otherwise a no-op context manager.
    amp_context = torch.cuda.amp.autocast if self.args.use_amp else contextlib.nullcontext
    # With features == 'MS' only the last channel is scored; otherwise all.
    f_dim = -1 if self.args.features == 'MS' else 0

    preds = []
    trues = []
    self.model.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, _batch_x_mark, _batch_y_mark) in enumerate(test_loader):
            # The time-mark tensors are not passed to the model below, so they
            # are not copied to the device.
            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float().to(self.device)

            with amp_context():
                if self.args.model == 'PathFormer':
                    # PathFormer's forward returns a second value (a balance
                    # loss); it is not needed for evaluation.
                    outputs, _balance_loss = self.model(batch_x)
                else:
                    outputs = self.model(batch_x)

            pred = outputs[:, -self.args.pred_len:, f_dim:].detach().cpu().numpy()
            true = batch_y[:, -self.args.pred_len:, f_dim:].detach().cpu().numpy()
            preds.append(pred)
            trues.append(true)

            if i % 20 == 0:
                input_np = batch_x.detach().cpu().numpy()
                true_series = np.concatenate((input_np[0, :, -1], true[0, :, -1]), axis=0)
                pred_series = np.concatenate((input_np[0, :, -1], pred[0, :, -1]), axis=0)
                visual(true_series, pred_series, os.path.join(folder_path, str(i) + '.pdf'))

    if not preds:
        raise ValueError('test loader yielded no batches for setting ' + repr(setting))

    if self.args.test_flop:
        # NOTE(review): only the input shape is passed here; similar helpers
        # also take the model as the first argument — verify the signature
        # of the imported test_params_flop.
        test_params_flop((batch_x.shape[1], batch_x.shape[2]))
        sys.exit()

    # Concatenating along the batch axis is identical to
    # np.array(batches).reshape(-1, L, C) for equal-sized batches, and also
    # works when the final batch is smaller (np.array would fail there).
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)

    mae, mse, _rmse, _mape, _mspe, rse, _corr = metric(preds, trues)
    summary = 'mse:{}, mae:{}, rse:{}'.format(mse, mae, rse)
    print(summary)
    # The context manager guarantees the file is closed even if a write fails.
    with open("result.txt", 'a', encoding='utf-8') as f:
        f.write(setting + " \n" + summary + '\n\n')
    return