in sagemaker-voice-classification/notebook/train.py [0:0]
import argparse
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

# NetM3, CoswareDataset, train, test, and save_model are assumed to be defined
# earlier in this file; only main() is shown in this excerpt.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=64, help="train batch size")
    parser.add_argument("--test-batch-size", type=int, default=64, help="test batch size")
    parser.add_argument("--epochs", type=int, default=2, help="number of epochs")
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument("--gamma", type=float, default=0.01, help="learning rate step gamma")
    parser.add_argument("--weight-decay", type=float, default=0.0001, help="optimizer weight decay (L2 regularization)")
    parser.add_argument("--stepsize", type=int, default=5, help="StepLR step size in epochs")
    parser.add_argument("--model", type=str, default="m3", help="model architecture name")
    parser.add_argument("--num-workers", type=int, default=30, help="DataLoader worker processes")
    parser.add_argument("--csv-file", type=str, default="breathing-deep-metadata.csv", help="metadata CSV file name")
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    parser.add_argument("--log-interval", type=int, default=10, help="logging interval")
    parser.add_argument("--localpath", type=str, default="data", help="local data directory, relative to the repo root")
    # Container environment
    parser.add_argument("--model-dir", type=str, default=os.getenv("SM_MODEL_DIR", "./"))
    if os.getenv("SM_HOSTS") is not None:
        # parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
        parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"])
        # parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])
        # print_files_in_path(os.environ["SM_CHANNEL_TRAINING"])
    args = parser.parse_args()
    print(args)
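    # Example local invocation (illustrative; the flags map to the argparse
    # options defined above):
    #   python train.py --epochs 2 --batch-size 64 --lr 0.1 --localpath data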
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    # On SageMaker, data is mounted under SM_CHANNEL_TRAINING; point the channel
    # at the sample or full dataset as needed.
    if os.getenv("SM_HOSTS") is not None:
        print("Running on SageMaker")
        datapath = Path(args.data_dir)
        csv_path = datapath / args.csv_file
        file_path = datapath
    # Locally, use the smaller data subset for testing first.
    else:
        print("Running locally")
        full_filepath = Path(__file__).resolve()
        parent_path = full_filepath.parent.parent
        csv_path = parent_path / args.localpath / args.csv_file
        file_path = parent_path / args.localpath
    print("csv_path", csv_path)
    print("file_path", file_path)
    kwargs = {"num_workers": args.num_workers, "pin_memory": True} if torch.cuda.is_available() else {}
    print(kwargs)
    dataset = CoswareDataset(
        csv_path=csv_path,
        file_path=file_path,
        new_sr=8000,       # target sample rate after resampling (assumed Hz)
        audio_len=20,      # clip length (assumed seconds)
        sampling_ratio=5,
    )
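    # Optional sanity check (illustrative, not in the original script): inspect
    # a single sample before training. Assumes CoswareDataset.__getitem__
    # returns a (waveform_tensor, label) pair.
    #   waveform, label = dataset[0]
    #   print(waveform.shape, label)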
    # 80/20 train/test split; reproducible thanks to torch.manual_seed above.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    print(f"train_size: {train_size}, test_size: {test_size}")
    # Forward the CUDA loader settings (num_workers/pin_memory) built above.
    # The test loader uses the dedicated --test-batch-size flag, and evaluation
    # order does not matter, so the test set is not shuffled.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
print("Loading model:", args.model)
if args.model == "m3":
model = NetM3()
else:
model = NetM3()
if torch.cuda.device_count() > 1:
print("There are {} gpus".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
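    # StepLR multiplies the learning rate by gamma every `stepsize` epochs:
    #   lr_epoch = lr * gamma ** floor(completed_epochs / stepsize)
    # With the defaults (lr=0.1, gamma=0.01, stepsize=5), epochs 1-5 run at 0.1
    # and epochs 6-10 at 0.001; with the default epochs=2 no decay ever occurs.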
    log_interval = args.log_interval
    for epoch in range(1, args.epochs + 1):
        print("Learning rate:", scheduler.get_last_lr()[0])
        train(model, epoch, train_loader, device, optimizer, log_interval)
        loss, accuracy = test(model, test_loader, device)
        print("Epoch", epoch, "test loss:", loss, "test accuracy:", accuracy)
        scheduler.step()
    save_model(model, args.model_dir)
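
# Entry point, assuming the script is executed directly (as the SageMaker
# training toolkit does when it invokes train.py).
if __name__ == "__main__":
    main()

# Illustrative launch sketch (not part of this script): submitting train.py as
# a SageMaker training job. The role/instance/version values below are
# assumptions; hyperparameter keys map to the argparse flags above, and the
# "training" channel populates SM_CHANNEL_TRAINING.
#
#   from sagemaker.pytorch import PyTorch
#   estimator = PyTorch(
#       entry_point="train.py",
#       role=role,                      # your IAM role ARN
#       instance_count=1,
#       instance_type="ml.p3.2xlarge",
#       framework_version="1.8.0",
#       py_version="py3",
#       hyperparameters={"epochs": 10, "batch-size": 64, "lr": 0.1},
#   )
#   estimator.fit({"training": "s3://<bucket>/<prefix>"})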