# Entry point of 01-byoc/code/train.py

def main():
    """Parse CLI arguments, prepare directories/data/model, and run training.

    Reads all configuration from the command line, bundles it into a
    ParameterSetting object, ensures the result directories exist, builds the
    model and the train/validation dataloaders, then hands off to train_model.
    """
    parser = ArgumentParser()
    # data or model path setting
    parser.add_argument("--csv_path", type=str, default='/DATA/hucheng/competition/official/preliminary/after_trim/meta_train.csv', help='the path of train csv file')
    parser.add_argument("--data_dir", type=str, default="/DATA/hucheng/competition/official/preliminary/after_trim/train", help="the directory of sound data")
    parser.add_argument("--save_root", type=str, default="./results", help="the root of results")
    # NOTE(review): default filename "final_mode.pkl" looks like a typo for
    # "final_model.pkl", but it is kept as-is for backward compatibility.
    parser.add_argument("--model_file", type=str, default="./results/final_mode.pkl", help="the path of the final model file")
    parser.add_argument("--resume", type=str, default=None, help="the path of resume training model")
    # training parameter setting
    parser.add_argument("--model_name", type=str, default='VGGish', choices=['VGGish'], help='the algorithm we used')
    parser.add_argument("--val_split", type=float, default=0.1, help="the ratio of validation set. 0 means there's no validation dataset")
    parser.add_argument("--epochs", type=int, default=20, help="epoch number")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam"])
    parser.add_argument("--scheduler", type=str, default="steplr", choices=["steplr"])
    parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
    parser.add_argument("--num_class", type=int, default=6, help="number of classes")
    parser.add_argument("--normalize", type=str, default=None, choices=[None, 'rms', 'peak'], help="normalize the input before fed into model")
    parser.add_argument("--preload", action='store_true', default=False, help="whether to convert to melspectrogram first before start training")
    # data augmentation setting (SpecAugment-style time/freq masking)
    parser.add_argument("--spec_aug", action='store_true', default=False)
    parser.add_argument("--time_drop_width", type=int, default=64)
    parser.add_argument("--time_stripes_num", type=int, default=2)
    parser.add_argument("--freq_drop_width", type=int, default=8)
    parser.add_argument("--freq_stripes_num", type=int, default=2)
    # preprocessing setting
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--nfft", type=int, default=200)
    parser.add_argument("--hop", type=int, default=80)
    parser.add_argument("--mel", type=int, default=64)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(args))

    ##################
    # config setting #
    ##################

    params = ParameterSetting(args.csv_path, args.data_dir, args.save_root, args.model_file, args.model_name, args.val_split,
                              args.epochs, args.batch_size, args.lr, args.num_class,
                              args.time_drop_width, args.time_stripes_num, args.freq_drop_width, args.freq_stripes_num,
                              args.sr, args.nfft, args.hop, args.mel, args.resume, args.normalize, args.preload,
                              args.spec_aug, args.optimizer, args.scheduler)

    # BUG FIX: previously the 'snapshots' and 'log' subfolders were only created
    # when save_root itself was newly created; on a rerun with an existing
    # save_root the subfolders were never ensured. makedirs(exist_ok=True)
    # guarantees all three directories exist and also handles nested paths.
    if not os.path.exists(params.save_root):
        print("create folder: {}".format(params.save_root))
    os.makedirs(params.save_root, exist_ok=True)
    os.makedirs(os.path.join(params.save_root, 'snapshots'), exist_ok=True)
    os.makedirs(os.path.join(params.save_root, 'log'), exist_ok=True)

    ###################
    # model preparing #
    ###################

    model = prepare_model(params)

    ##################
    # data preparing #
    ##################

    print("Preparing training/validation data...")
    dataset = SoundDataset(params)

    train_dataloader = SoundDataLoader(dataset, batch_size=params.batch_size, shuffle=True, validation_split=params.val_split, pin_memory=True)
    val_dataloader = train_dataloader.split_validation()

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}
    dataset_sizes = {'train': len(train_dataloader.sampler), 'val': len(train_dataloader.valid_sampler)}
    print("train size: {}, val size: {}".format(dataset_sizes['train'], dataset_sizes['val']))

    ##################
    # model training #
    ##################

    # start to train the model
    train_model(model, params, dataloaders, dataset_sizes)