def main()

in 01-byoc/code/test.py [0:0]


def _run_inference(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and collect its predictions.

    Args:
        model: the network, expected to already be in eval mode and on ``device``.
        dataloader: iterable yielding ``(spec, gt)`` pairs of input batch and labels.
        device: torch device each input batch is moved to before the forward pass.

    Returns:
        A tuple ``(y_true, y_pred, y_prob)`` of per-sample lists: ground-truth labels,
        predicted labels (index of the highest softmax score) and softmax
        probability vectors, all converted to numpy values.
    """
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for spec, gt in tqdm(dataloader):
            outputs = model(spec.to(device))
            # Softmax over dim 1 so y_prob holds per-class probabilities for the ROC-AUC metric.
            probs = torch.nn.functional.softmax(outputs, dim=1)
            _, preds = torch.max(probs, 1)

            y_true.extend(gt.detach().cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
    return y_true, y_pred, y_prob


def main():
    """Evaluate one or more saved checkpoints on the test set and write metric reports.

    Parses the command line, builds the model and the test dataloader once, then for
    every path given in ``--model_path``: loads the state dict, runs inference, prints
    the metrics and writes them to ``<saved_root>/<saved_name>_<index>.txt``.
    """
    parser = ArgumentParser()
    # arguments for test
    parser.add_argument("--test_csv", type=str, default='/DATA/hucheng/competition/official/preliminary/after_trim/meta_public_test.csv')
    parser.add_argument("--data_dir", type=str, default="/DATA/hucheng/competition/official/preliminary/after_trim/public_test", help="the directory of sound data")
    parser.add_argument("--model_name", type=str, default='VGGish', choices=['VGGish'], help='the algorithm we used')
    parser.add_argument("--model_path", nargs="+", default=['model.pkl'])
    parser.add_argument("--batch_size", type=int, default=128, help="the batch size")
    # NOTE(review): --threshold is parsed but not read anywhere in this function.
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--num_class", type=int, default=6, help="number of classes")
    parser.add_argument("--saved_root", type=str, default='results/test', help="the path of test results.")
    parser.add_argument("--saved_name", type=str, default='test_results', help="the prefix of test files")
    # preprocessing setting
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--nfft", type=int, default=200)
    parser.add_argument("--hop", type=int, default=80)
    parser.add_argument("--mel", type=int, default=64)
    parser.add_argument("--normalize", type=str, default=None, choices=[None, 'rms', 'peak'], help="normalize the input before fed into model")
    parser.add_argument("--preload", action='store_true', default=False)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(args))
    params = ParameterSetting(csv_path=args.test_csv, data_dir=args.data_dir, batch_size=args.batch_size, num_class=args.num_class, sr=args.sr,
                              nfft=args.nfft, hop=args.hop, mel=args.mel, normalize=args.normalize, preload=args.preload)

    ###################
    # model preparing #
    ###################
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_name == 'VGGish':
        model = VGGish(params)
    else:
        # argparse `choices` should make this unreachable; fail loudly instead of on a None model later.
        raise ValueError("unsupported model_name: {}".format(args.model_name))
    model = model.to(device)

    ##################
    # data preparing #
    ##################
    dataset = SoundDataset_test(params)
    dataloader = DataLoader(dataset, batch_size=params.batch_size, shuffle=False)
    print("the number of wavfiles : {}".format(len(dataset)))

    # makedirs (not mkdir): the default saved_root is nested, and its parent may not exist yet.
    os.makedirs(args.saved_root, exist_ok=True)

    ##################
    #  test the file #
    ##################
    for model_idx, model_path in enumerate(args.model_path):
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        since = time.time()
        y_true, y_pred, y_prob = _run_inference(model, dataloader, device)
        time_elapsed = time.time() - since
        print('test complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        # Compute each metric once; it is both printed and written to the report file.
        confusion = cfm(y_true, y_pred)
        report = classification_report(y_true, y_pred)
        auc_score = roc_auc(y_true, y_prob)

        print(y_prob[:5])
        print(confusion)
        print(report)
        print(auc_score)

        result_path = os.path.join(args.saved_root, "{}_{}.txt".format(args.saved_name, model_idx))
        with open(result_path, 'w') as f:
            f.write(str(confusion) + "\n")
            f.write(report + "\n")
            f.write("roc auc score: " + str(auc_score) + "\n")