in 01-byoc/code/test.py [0:0]
def _run_inference(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

    Args:
        model: network already switched to eval mode and moved to ``device``.
        dataloader: iterable yielding ``(spec, gt)`` batches.
        device: torch device the input batches are moved to.

    Returns:
        ``(y_true, y_pred, y_prob)`` lists, in dataloader order:
        the ``gt`` labels, the argmax predictions, and the softmax
        scores along dim 1 (the dimension the predictions are taken over).
    """
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        # tqdm(dataloader) picks up len(dataloader), so the bar shows a total.
        for spec, gt in tqdm(dataloader):
            probs = torch.nn.functional.softmax(model(spec.to(device)), dim=1)
            _, preds = torch.max(probs, 1)
            y_true.extend(gt.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
    return y_true, y_pred, y_prob


def main():
    """Evaluate one or more trained checkpoints on the test set and save reports.

    Parses command-line arguments and builds the model and test dataloader.
    Then, for every path in ``--model_path``, it:

    - loads the checkpoint and runs inference;
    - prints the confusion matrix, classification report and ROC AUC;
    - writes them to ``<saved_root>/<saved_name>_<idx>.txt``.
    """
    parser = ArgumentParser()
    # arguments for test
    parser.add_argument("--test_csv", type=str, default='/DATA/hucheng/competition/official/preliminary/after_trim/meta_public_test.csv')
    parser.add_argument("--data_dir", type=str, default="/DATA/hucheng/competition/official/preliminary/after_trim/public_test", help="the directory of sound data")
    parser.add_argument("--model_name", type=str, default='VGGish', choices=['VGGish'], help='the algorithm we used')
    parser.add_argument("--model_path", nargs="+", default=['model.pkl'])
    parser.add_argument("--batch_size", type=int, default=128, help="the batch size")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--num_class", type=int, default=6, help="number of classes")
    parser.add_argument("--saved_root", type=str, default='results/test', help="the path of test results.")
    parser.add_argument("--saved_name", type=str, default='test_results', help="the prefix of test files")
    # preprocessing setting
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--nfft", type=int, default=200)
    parser.add_argument("--hop", type=int, default=80)
    parser.add_argument("--mel", type=int, default=64)
    parser.add_argument("--normalize", type=str, default=None, choices=[None, 'rms', 'peak'], help="normalize the input before fed into model")
    parser.add_argument("--preload", action='store_true', default=False)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(args))
    params = ParameterSetting(csv_path=args.test_csv, data_dir=args.data_dir, batch_size=args.batch_size, num_class=args.num_class, sr=args.sr,
                              nfft=args.nfft, hop=args.hop, mel=args.mel, normalize=args.normalize, preload=args.preload)

    ###################
    # model preparing #
    ###################
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_name == 'VGGish':
        model = VGGish(params)
    else:
        # Unreachable while argparse restricts `choices`; guards future additions.
        raise ValueError("unsupported model_name: {}".format(args.model_name))
    model = model.to(device)

    ##################
    # data preparing #
    ##################
    dataset = SoundDataset_test(params)
    dataloader = DataLoader(dataset, batch_size=params.batch_size, shuffle=False)
    print("the number of wavfiles : {}".format(len(dataset)))

    # makedirs (not mkdir) so the nested default 'results/test' works even when
    # 'results/' does not exist yet; done once instead of per checkpoint.
    os.makedirs(args.saved_root, exist_ok=True)

    ##################
    # test the file #
    ##################
    for model_idx, checkpoint_path in enumerate(args.model_path):
        logger.info("Evaluating checkpoint %d: %s", model_idx, checkpoint_path)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()

        since = time.time()
        y_true, y_pred, y_prob = _run_inference(model, dataloader, device)
        time_elapsed = time.time() - since
        print('test complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print(y_prob[:5])

        # Compute each metric once and reuse it for stdout and the saved report.
        confusion = cfm(y_true, y_pred)
        report = classification_report(y_true, y_pred)
        auc = roc_auc(y_true, y_prob)
        print(confusion)
        print(report)
        print(auc)

        result_path = os.path.join(args.saved_root, "{}_{}.txt".format(args.saved_name, model_idx))
        with open(result_path, 'w', encoding='utf-8') as f:
            f.write(str(confusion) + "\n")
            f.write(report + "\n")
            f.write("roc auc score: " + str(auc) + "\n")