in 01-byoc/code/train.py [0:0]
def main():
parser = ArgumentParser()
# data or model path setting
parser.add_argument("--csv_path", type=str, default='/DATA/hucheng/competition/official/preliminary/after_trim/meta_train.csv', help='the path of train csv file')
parser.add_argument("--data_dir", type=str, default="/DATA/hucheng/competition/official/preliminary/after_trim/train", help="the directory of sound data")
parser.add_argument("--save_root", type=str, default="./results", help="the root of results")
parser.add_argument("--model_file", type=str, default="./results/final_mode.pkl", help="the root of results")
parser.add_argument("--resume", type=str, default=None, help="the path of resume training model")
# training parameter setting
parser.add_argument("--model_name", type=str, default='VGGish', choices=['VGGish'], help='the algorithm we used')
parser.add_argument("--val_split", type=float, default=0.1, help="the ratio of validation set. 0 means there's no validation dataset")
parser.add_argument("--epochs", type=int, default=20, help="epoch number")
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument("--optimizer", type=str, default="adam", choices=["adam"])
parser.add_argument("--scheduler", type=str, default="steplr", choices=["steplr"])
parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
parser.add_argument("--num_class", type=int, default=6, help="number of classes")
parser.add_argument("--normalize", type=str, default=None, choices=[None, 'rms', 'peak'], help="normalize the input before fed into model")
parser.add_argument("--preload", action='store_true', default=False, help="whether to convert to melspectrogram first before start training")
# data augmentation setting
parser.add_argument("--spec_aug", action='store_true', default=False)
parser.add_argument("--time_drop_width", type=int, default=64)
parser.add_argument("--time_stripes_num", type=int, default=2)
parser.add_argument("--freq_drop_width", type=int, default=8)
parser.add_argument("--freq_stripes_num", type=int, default=2)
# proprocessing setting
parser.add_argument("--sr", type=int, default=8000)
parser.add_argument("--nfft", type=int, default=200)
parser.add_argument("--hop", type=int, default=80)
parser.add_argument("--mel", type=int, default=64)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logger.info("Arguments: %s", pformat(args))
##################
# config setting #
##################
params = ParameterSetting(args.csv_path, args.data_dir, args.save_root, args.model_file, args.model_name, args.val_split,
args.epochs, args.batch_size, args.lr, args.num_class,
args.time_drop_width, args.time_stripes_num, args.freq_drop_width, args.freq_stripes_num,
args.sr, args.nfft, args.hop, args.mel, args.resume, args.normalize, args.preload,
args.spec_aug, args.optimizer, args.scheduler)
if not os.path.exists(params.save_root):
os.mkdir(params.save_root)
print("create folder: {}".format(params.save_root))
if not os.path.exists(os.path.join(params.save_root, 'snapshots')):
os.mkdir(os.path.join(params.save_root, 'snapshots'))
if not os.path.exists(os.path.join(params.save_root, 'log')):
os.mkdir(os.path.join(params.save_root, 'log'))
###################
# model preparing #
###################
model = prepare_model(params)
##################
# data preparing #
##################
print("Preparing training/validation data...")
dataset = SoundDataset(params)
train_dataloader = SoundDataLoader(dataset, batch_size=params.batch_size, shuffle=True, validation_split=params.val_split, pin_memory=True)
val_dataloader = train_dataloader.split_validation()
dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataloader.sampler), 'val': len(train_dataloader.valid_sampler)}
print("train size: {}, val size: {}".format(dataset_sizes['train'], dataset_sizes['val']))
##################
# model training #
##################
# start to train the model
train_model(model, params, dataloaders, dataset_sizes)