def run()

in sagemaker/source/train.py [0:0]


def run():
    """Train the network end to end and periodically checkpoint it.

    Steps, in order:

    1. Disable cuDNN and parse the command-line arguments.
    2. Resolve the train/test input files and the output directory. When the
       ``SM_CHANNEL_TRAIN`` / ``SM_CHANNEL_TEST`` / ``SM_OUTPUT_DATA_DIR``
       environment variables are present they take precedence over the
       corresponding command-line values; otherwise the command-line values
       are used as given.
    3. Build the train/test datasets and data loaders (the training loader
       draws batches through a ``StratifiedSampler`` over the train labels).
    4. Train for ``args.epochs`` epochs with Adam and cross-entropy loss,
       printing train/test loss, accuracy and AUC after every epoch.
    5. Write ``net.pth`` into the output directory every 20th epoch and after
       the final epoch.

    Raises:
        AssertionError: if the number of sensor headers differs from the
            last entry of ``conv_channels``.
    """
    torch.backends.cudnn.enabled = False

    # The presence of SM_HPS only selects the logging format used below.
    is_sm_mode = "SM_HPS" in os.environ

    args = parse_args()

    # Decoded once here for validation and for the feature count only; every
    # external consumer below still receives its own freshly decoded value,
    # so none of them can observe another's in-place modifications.
    sensor_headers = json.loads(args.sensor_headers)
    assert len(sensor_headers) == json.loads(args.conv_channels)[-1], "The last conv filter must be equal the the number of sensor_headers"

    if 'SM_CHANNEL_TRAIN' in os.environ:
        train_path = os.path.join(
            os.environ['SM_CHANNEL_TRAIN'],
            os.path.basename(args.train_input_filename))
    else:
        train_path = args.train_input_filename

    # Guard on the variable that is actually read. Checking the TRAIN channel
    # here raised KeyError when only the train channel was configured and
    # ignored the test channel when only the test channel was configured.
    if 'SM_CHANNEL_TEST' in os.environ:
        test_path = os.path.join(
            os.environ['SM_CHANNEL_TEST'],
            os.path.basename(args.test_input_filename))
    else:
        test_path = args.test_input_filename

    if 'SM_OUTPUT_DATA_DIR' in os.environ:
        output_path = os.path.join(os.environ['SM_OUTPUT_DATA_DIR'], "output")
    else:
        output_path = args.output_path

    # makedirs also creates missing parent directories and tolerates the
    # directory already existing (no check-then-create race).
    os.makedirs(output_path, exist_ok=True)

    train_ds = PMDataset_torch(
        train_path,
        target_column=args.target_column,
        standardize=True,
        sensor_headers=json.loads(args.sensor_headers))
    test_ds = PMDataset_torch(
        test_path,
        target_column=args.target_column,
        standardize=True,
        sensor_headers=json.loads(args.sensor_headers))

    batch_size = args.batch_size
    class_labels = torch.tensor(train_ds.labels)
    ss = StratifiedSampler(class_labels, batch_size)
    # shuffle must stay False because a custom sampler is supplied.
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size,
        num_workers=multiprocessing.cpu_count()-1,
        shuffle=False,
        sampler=ss)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=True)

    net = Network(num_features=len(sensor_headers),
                  fc_hidden_units=json.loads(args.fc_hidden_units),
                  conv_channels=json.loads(args.conv_channels),
                  dropout_strength=args.dropout)
    net = net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    critereon = torch.nn.CrossEntropyLoss()
    checkpoint_file = os.path.join(output_path, "net.pth")
    last_epoch = args.epochs - 1

    for e in range(args.epochs):
        train_loss, train_acc, train_auc = run_epoch(net, train_dl, optimizer,
                                                     critereon, is_train=True)
        test_loss, test_acc, test_auc = run_epoch(net, test_dl, optimizer,
                                                  critereon, is_train=False)

        if is_sm_mode:
            # One metric per line.
            print("Epoch: {}".format(e))
            print("Train loss: {:0.4f}".format(train_loss))
            print("Train acc: {:0.4f}".format(train_acc))
            print("Train auc: {:0.4f}".format(train_auc))
            print("Test loss: {:0.4f}".format(test_loss))
            print("Test acc: {:0.4f}".format(test_acc))
            print("Test auc: {:0.4f}".format(test_auc))
        else:
            # Compact single-line summary per epoch.
            print("{} train loss: {:0.4f} acc {:0.4f} auc {:0.4f}|".format(e, train_loss, train_acc, train_auc), end="")
            print("test loss {:0.4f} acc {:0.4f} auc {:0.4f}".format(test_loss, test_acc, test_auc))

        # Checkpoint every 20th epoch and also after the final epoch, so the
        # saved file always reflects the fully trained weights.
        if e % 20 == 0 or e == last_epoch:
            torch.save(
                {"net": net.state_dict(),
                 "sensor_headers": json.loads(args.sensor_headers),
                 "fc_hidden_units": json.loads(args.fc_hidden_units),
                 "conv_channels": json.loads(args.conv_channels)},
                checkpoint_file)