def train()

in sagemaker/tf-deploy/code/train.py [0:0]


def train(args):
    """Train ``SmallConv`` on MNIST with a custom loop and export the model.

    Loads the train/test channels, casts and normalizes the images, builds
    batched ``tf.data`` pipelines, runs ``args.epochs`` epochs of Adam
    training (evaluating on the test set after every epoch and printing the
    running metrics), then saves the model under a versioned sub-directory
    of ``args.model_dir``.

    Args:
        args: Parsed hyper-parameters. This function reads ``train`` and
            ``test`` (data directories), ``batch_size``, ``learning_rate``,
            ``beta_1``, ``beta_2``, ``epochs`` and ``model_dir``.

    Returns:
        None. The trained model is written to ``<args.model_dir>/00000000``.
    """
    # load numpy arrays from the train / test channels
    x_train, y_train = mnist_to_numpy(data_dir=args.train, train=True)
    x_test, y_test = mnist_to_numpy(data_dir=args.test, train=False)

    x_train, x_test = x_train.astype(np.float32), x_test.astype(np.float32)

    # normalize each image to mean 0 and std 1; reducing over axes 1 and 2
    # implies the arrays are (N, H, W) here — TODO confirm against mnist_to_numpy
    x_train, x_test = normalize(x_train, (1, 2)), normalize(x_test, (1, 2))

    # append a channel axis: tf uses the depth-minor (channels-last) convention
    x_train = np.expand_dims(x_train, axis=3)
    x_test = np.expand_dims(x_test, axis=3)

    # build the input pipelines; the shuffle buffer spans the whole training
    # set, and only the training data is shuffled
    train_loader = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(len(x_train)).batch(args.batch_size)

    test_loader = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(args.batch_size)

    model = SmallConv()
    # NOTE(review): the loop below uses its own optimizer and loss, so this
    # argument-less compile() is not used for training; kept unchanged in
    # case model.save() relies on it — confirm before removing.
    model.compile()
    # from_logits=True: the loss expects unnormalized model outputs (logits)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=args.learning_rate,
        beta_1=args.beta_1,
        beta_2=args.beta_2,
    )

    # running metrics; all four are reset at the start of every epoch
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function
    def train_step(images, labels):
        """Run one optimization step on a batch and update training metrics."""
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, predictions)

    @tf.function
    def test_step(images, labels):
        """Evaluate a batch in inference mode and update test metrics."""
        predictions = model(images, training=False)
        t_loss = loss_fn(labels, predictions)
        test_loss(t_loss)
        test_accuracy(labels, predictions)

    print("Training starts ...")
    for epoch in range(args.epochs):
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            metric.reset_states()

        for images, labels in train_loader:
            train_step(images, labels)

        for images, labels in test_loader:
            test_step(images, labels)

        print(
            f'Epoch {epoch + 1}, '
            f'Loss: {train_loss.result()}, '
            f'Accuracy: {train_accuracy.result() * 100}, '
            f'Test Loss: {test_loss.result()}, '
            f'Test Accuracy: {test_accuracy.result() * 100}'
        )

    # Save the model.
    # A version number is needed for the serving container
    # to load the model
    version = '00000000'
    ckpt_dir = os.path.join(args.model_dir, version)
    # exist_ok=True avoids the race in a separate exists()-then-makedirs() check
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save(ckpt_dir)