def model()

in training/distributed-training/tensorflow/multi_worker_mirrored_strategy/mnist-distributed.py [0:0]


def model(x_train, y_train, x_test, y_test, strategy):
    """Build, train and evaluate a small dense classifier under ``strategy``.

    The network is Flatten -> Dense(1024, relu) -> Dropout(0.4) ->
    Dense(10, softmax), compiled with the Adam optimizer, sparse categorical
    cross-entropy loss and an accuracy metric.

    Args:
        x_train: Training inputs passed to ``fit``.
        y_train: Training labels passed to ``fit``.
        x_test: Evaluation inputs passed to ``evaluate``.
        y_test: Evaluation labels passed to ``evaluate``.
        strategy: A ``tf.distribute`` strategy. Layers are created and the
            model is compiled inside its ``scope()``.

    Returns:
        The trained Keras model. The result of ``evaluate`` is not returned.
    """
    # Model construction and compilation must happen inside the strategy
    # scope so that variables are created under the distribution strategy.
    with strategy.scope():
        layer_stack = [
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1024, activation=tf.nn.relu),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Dense(10, activation=tf.nn.softmax),
        ]
        net = tf.keras.models.Sequential(layer_stack)
        net.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

    # fit/evaluate run outside the scope. No epochs or batch size are given,
    # so Keras defaults apply.
    net.fit(x_train, y_train)
    net.evaluate(x_test, y_test)
    return net