in tutorials-and-examples/tpu-examples/single-host-inference/jax/bert/export_bert_model.py
import os

from absl import logging
from jax.experimental import jax2tf
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
from transformers import AutoTokenizer
from transformers import FlaxBertForMaskedLM


def export_bert_base_uncased():
    _MAX_INPUT_SIZE = 64
    # The BERT implementation is HuggingFace's `bert-base-uncased`; see the
    # [model card](https://huggingface.co/bert-base-uncased).
    _BERT_BASE_UNCASED = "bert-base-uncased"
    logging.info("Load Flax Bert model.")
    model = FlaxBertForMaskedLM.from_pretrained(_BERT_BASE_UNCASED)
    logging.info("Use jax2tf to convert the Flax model.")

    # Convert the JAX model to a TF2 function.
    def predict_fn(params, input_ids, attention_mask, token_type_ids):
        return model(
            params=params,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
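
    # Passing `params` explicitly, instead of closing over `model.params`,
    # lets jax2tf trace the weights as an input, so they can be bound to the
    # tf.Variable objects created below and captured in the SavedModel.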
    params_vars = tf.nest.map_structure(tf.Variable, model.params)
    tf_predict = tf.function(
        lambda input_ids, attention_mask, token_type_ids: jax2tf.convert(
            predict_fn,
            enable_xla=True,
            with_gradient=False,
            # `None` keeps the concrete parameter shapes; "b" marks the batch
            # dimension of each input as symbolic, with the sequence length
            # fixed at _MAX_INPUT_SIZE.
            polymorphic_shapes=[
                None,
                f"(b, {_MAX_INPUT_SIZE})",
                f"(b, {_MAX_INPUT_SIZE})",
                f"(b, {_MAX_INPUT_SIZE})",
            ],
        )(params_vars, input_ids, attention_mask, token_type_ids),
        input_signature=[
            tf.TensorSpec(
                shape=(None, _MAX_INPUT_SIZE), dtype=tf.int32, name="input_ids"
            ),
            tf.TensorSpec(
                shape=(None, _MAX_INPUT_SIZE),
                dtype=tf.int32,
                name="attention_mask",
            ),
            tf.TensorSpec(
                shape=(None, _MAX_INPUT_SIZE),
                dtype=tf.int32,
                name="token_type_ids",
            ),
        ],
        autograph=False,
    )
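
    # with_gradient=False exports an inference-only function, and
    # autograph=False is safe because the lambda contains no Python control
    # flow. enable_xla=True (the jax2tf default) keeps the full set of XLA
    # ops available during lowering.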
    tf_model = tf.Module()
    tf_model.tf_predict = tf_predict
    tf_model._variables = tf.nest.flatten(params_vars)
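    # tf.Module tracks variables assigned to attributes, so attaching the
    # flattened list keeps every weight reachable from the object graph and
    # ensures tf.saved_model.save serializes it.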
    # Save the TF model.
    logging.info("Save the TF model.")
    _CPU_MODEL_PATH = "/tmp/jax/bert_cpu/1"
    _TPU_MODEL_PATH = "/tmp/jax/bert_tpu/1"
    tf.io.gfile.makedirs(_CPU_MODEL_PATH)
    tf.io.gfile.makedirs(_TPU_MODEL_PATH)
    tf.saved_model.save(
        obj=tf_model,
        export_dir=_CPU_MODEL_PATH,
        signatures={
            "serving_default": tf_model.tf_predict.get_concrete_function()
        },
        options=tf.saved_model.SaveOptions(
            function_aliases={"tpu_func": tf_model.tf_predict}
        ),
    )
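
    # NOTE: _TPU_MODEL_PATH is only created here, not written to; presumably a
    # later step (outside this excerpt) converts the CPU SavedModel for TPU,
    # using the "tpu_func" entry in function_aliases to identify the function
    # to compile.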
    # Save a warmup request.
    _TEXT = ["Today is a [MASK] day.", "My dog is [MASK]."]
    tokenizer = AutoTokenizer.from_pretrained(
        _BERT_BASE_UNCASED, model_max_length=_MAX_INPUT_SIZE
    )
    tf_inputs = tokenizer(
        _TEXT, return_tensors="tf", padding="max_length", truncation=True
    )
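    # Padding to max_length gives the warmup tensors the exact
    # (batch, _MAX_INPUT_SIZE) shapes declared in the input signature above.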
    _EXTRA_ASSETS_DIR = "assets.extra"
    _WARMUP_REQ_FILE = "tf_serving_warmup_requests"
    assets_dir = os.path.join(_CPU_MODEL_PATH, _EXTRA_ASSETS_DIR)
    tf.io.gfile.makedirs(assets_dir)
    with tf.io.TFRecordWriter(
        os.path.join(assets_dir, _WARMUP_REQ_FILE)
    ) as writer:
        request = predict_pb2.PredictRequest()
        for key, val in tf_inputs.items():
            request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
        log = prediction_log_pb2.PredictionLog(
            predict_log=prediction_log_pb2.PredictLog(request=request)
        )
        writer.write(log.SerializeToString())
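

# A minimal sketch (an illustration, not part of the original export script)
# of how the saved CPU model could be loaded back and queried. The signature
# name and input keys match the export above; the example sentence is
# arbitrary.
def _smoke_test_saved_model(model_path="/tmp/jax/bert_cpu/1"):
    tokenizer = AutoTokenizer.from_pretrained(
        "bert-base-uncased", model_max_length=64
    )
    inputs = tokenizer(
        ["Paris is the [MASK] of France."],
        return_tensors="tf",
        padding="max_length",
        truncation=True,
    )
    loaded = tf.saved_model.load(model_path)
    serving_fn = loaded.signatures["serving_default"]
    outputs = serving_fn(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        token_type_ids=inputs["token_type_ids"],
    )
    logging.info(
        "Output tensors: %s", {k: v.shape for k, v in outputs.items()}
    )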