in community-content/tf_keras_text_classification_distributed_single_worker_gpus_with_gcloud_local_run_and_vertex_sdk/trainer/task.py
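
# NOTE: this excerpt assumes the module-level imports and helpers defined
# elsewhere in task.py; judging from usage below, roughly:
#   import os
#   import tensorflow as tf
#   from tensorflow.keras import layers, losses
#   from tensorflow.keras.layers import TextVectorization
# plus the constants VOCAB_SIZE and MAX_SEQUENCE_LENGTH and the helpers
# parse_args, download_data, load_dataset, build_model, train, and the
# distribution_utils module.
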
def main():
  args = parse_args()

  local_data_dir = './tmp/data'
  local_model_dir = './tmp/model'
  local_checkpoint_dir = './tmp/checkpoints'
  local_tensorboard_log_dir = './tmp/logs'

  model_dir = args.model_dir or local_model_dir
  tensorboard_log_dir = args.tensorboard_log_dir or local_tensorboard_log_dir
  checkpoint_dir = args.checkpoint_dir or local_checkpoint_dir
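
  # On Vertex AI, Cloud Storage buckets are mounted via Cloud Storage FUSE,
  # so a gs://bucket/path URI can be read and written as /gcs/bucket/path.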
  gs_prefix = 'gs://'
  gcsfuse_prefix = '/gcs/'
  if model_dir and model_dir.startswith(gs_prefix):
    model_dir = model_dir.replace(gs_prefix, gcsfuse_prefix)
  if tensorboard_log_dir and tensorboard_log_dir.startswith(gs_prefix):
    tensorboard_log_dir = tensorboard_log_dir.replace(gs_prefix, gcsfuse_prefix)
  if checkpoint_dir and checkpoint_dir.startswith(gs_prefix):
    checkpoint_dir = checkpoint_dir.replace(gs_prefix, gcsfuse_prefix)

  class_names = ['csharp', 'java', 'javascript', 'python']
  class_indices = dict(zip(class_names, range(len(class_names))))
  num_classes = len(class_names)
  print(f' class names: {class_names}')
  print(f' class indices: {class_indices}')
  print(f' num classes: {num_classes}')
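
  # Under tf.distribute.MirroredStrategy each replica processes
  # args.batch_size examples per step, so the effective (global) batch size
  # is the per-replica batch size times the number of replicas in sync.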
  strategy = distribution_utils.get_distribution_mirrored_strategy(
      num_gpus=args.num_gpus)
  print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

  global_batch_size = args.batch_size * strategy.num_replicas_in_sync
  print(f'Global batch size: {global_batch_size}')

  dataset_dir = download_data(local_data_dir)
  raw_train_ds, raw_val_ds, raw_test_ds = load_dataset(dataset_dir, global_batch_size)
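
  # TextVectorization in 'int' mode maps each token to an integer index in a
  # VOCAB_SIZE-token vocabulary; adapt() builds that vocabulary from the
  # training text only, and outputs are padded or truncated to
  # MAX_SEQUENCE_LENGTH.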
  vectorize_layer = TextVectorization(
      max_tokens=VOCAB_SIZE,
      output_mode='int',
      output_sequence_length=MAX_SEQUENCE_LENGTH)

  train_text = raw_train_ds.map(lambda text, labels: text)
  vectorize_layer.adapt(train_text)
  print('The vectorize_layer is adapted')

  def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label

  # Retrieve one batch of questions and labels from the dataset as a sanity
  # check on the vectorization.
  text_batch, label_batch = next(iter(raw_train_ds))
  first_question, first_label = text_batch[0], label_batch[0]
  print("Question", first_question)
  print("Label", first_label)
  print("Vectorized question:", vectorize_text(first_question, first_label)[0])
  train_ds = raw_train_ds.map(vectorize_text)
  val_ds = raw_val_ds.map(vectorize_text)
  test_ds = raw_test_ds.map(vectorize_text)

  AUTOTUNE = tf.data.AUTOTUNE

  def configure_dataset(dataset):
    return dataset.cache().prefetch(buffer_size=AUTOTUNE)

  train_ds = configure_dataset(train_ds)
  val_ds = configure_dataset(val_ds)
  test_ds = configure_dataset(test_ds)
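
  # Model variables must be created inside strategy.scope() so they are
  # mirrored across all GPUs.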
  print('Build model')
  loss = losses.SparseCategoricalCrossentropy(from_logits=True)
  optimizer = 'adam'
  metrics = ['accuracy']

  with strategy.scope():
    model = build_model(
        num_classes=num_classes,
        loss=loss,
        optimizer=optimizer,
        metrics=metrics,
    )

  train(
      model=model,
      train_dataset=train_ds,
      validation_dataset=val_ds,
      epochs=args.epochs,
      tensorboard_log_dir=tensorboard_log_dir,
      checkpoint_dir=checkpoint_dir
  )

  test_loss, test_accuracy = model.evaluate(test_ds)
  print("Int model accuracy: {:2.2%}".format(test_accuracy))
  with strategy.scope():
    export_model = tf.keras.Sequential(
        [vectorize_layer, model,
         layers.Activation('softmax')])
    export_model.compile(
        loss=losses.SparseCategoricalCrossentropy(from_logits=False),
        optimizer='adam',
        metrics=['accuracy'])

  loss, accuracy = export_model.evaluate(raw_test_ds)
  print("Accuracy: {:2.2%}".format(accuracy))
  model_path = os.path.join(model_dir, str(args.model_version))
  model.save(model_path)

  print(f'Model version {args.model_version} is saved to {model_dir}')
  print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')
  print(f'Checkpoints are saved to: {checkpoint_dir}')

  return