def main()

in sagemaker-voice-classification/notebook/train.py [0:0]


def main():
    """Train the classifier from command-line options and save the result.

    Steps, in order:

    1. Parse hyperparameters and paths from the command line.
    2. Seed torch and pick the CUDA device when one is available.
    3. Build a ``CoswareDataset`` and split it 80/20 into train/test subsets.
    4. Train ``NetM3`` with Adam and a ``StepLR`` schedule for ``--epochs``
       epochs, evaluating on the test subset after every epoch.
    5. Pass the trained model and ``--model-dir`` to ``save_model``.

    The run is treated as a SageMaker job when the ``SM_HOSTS`` environment
    variable is set. In that case ``--data-dir`` is registered (defaulting to
    ``SM_CHANNEL_TRAINING``) and the CSV and data files are read from it.
    Otherwise data is read from ``<parent of this script's directory>/<--localpath>``.

    Raises:
        KeyError: if ``SM_HOSTS`` is set but ``SM_CHANNEL_TRAINING`` is not.
    """
    # Evaluated once; it decides both which arguments exist and where data lives.
    on_sagemaker = os.getenv("SM_HOSTS") is not None

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=64, help="train batch size")
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="test batch size",
    )
    parser.add_argument("--epochs", type=int, default=2, help="number of epochs")
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument("--gamma", type=float, default=0.01, help="Learning rate step gamma")
    parser.add_argument("--weight-decay", type=float, default=0.0001, help="Optimizer regularization")
    parser.add_argument("--stepsize", type=int, default=5, help="Step LR size")
    parser.add_argument("--model", type=str, default="m3")
    parser.add_argument("--num-workers", type=int, default=30)
    parser.add_argument("--csv-file", type=str, default="breathing-deep-metadata.csv")
    parser.add_argument("--seed", type=int, default=1, help="seed")
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--localpath", type=str, default="data")

    # Container environment
    parser.add_argument("--model-dir", type=str, default=os.getenv("SM_MODEL_DIR", "./"))
    if on_sagemaker:
        parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"])

    args = parser.parse_args()
    print(args)
    torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Device:", device)

    if on_sagemaker:
        # On SageMaker, data is mounted at the training channel directory.
        print("Running on sagemaker")
        file_path = Path(args.data_dir)
    else:
        # Local run: data lives next to the directory that contains this script.
        print("Running on local")
        file_path = Path(__file__).resolve().parent.parent / args.localpath
    # Honour --csv-file in both modes (the local branch used to hard-code the name).
    csv_path = file_path / args.csv_file

    print("csv_path", csv_path)
    print("file_path", file_path)

    # Worker processes and pinned memory are only requested when CUDA is available.
    loader_kwargs = {"num_workers": args.num_workers, "pin_memory": True} if use_cuda else {}
    print(loader_kwargs)

    dataset = CoswareDataset(
        csv_path=csv_path,
        file_path=file_path,
        new_sr=8000,
        audio_len=20,
        sampling_ratio=5,
    )
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    print(f"train_size: {train_size}, test_size:{test_size}")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **loader_kwargs,
    )
    # The test loader uses its own batch size; --test-batch-size was previously ignored.
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    print("Loading model:", args.model)
    if args.model != "m3":
        # Only NetM3 is wired up here; make the fallback visible instead of silent.
        print(f"Unknown model {args.model!r}, falling back to m3")
    model = NetM3()

    if torch.cuda.device_count() > 1:
        print("There are {} gpus".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        print("Learning rate:", scheduler.get_last_lr()[0])
        train(model, epoch, train_loader, device, optimizer, args.log_interval)
        test(model, test_loader, device)
        # Step the schedule once per epoch, after the optimizer updates of that epoch.
        scheduler.step()

    save_model(model, args.model_dir)