# tutorials-and-examples/tpu-examples/single-host-inference/jax/bert/bert_request.py
# Note: imports reconstructed so this excerpt is self-contained; the real
# file's header may differ.
import argparse
import logging

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
from transformers import AutoTokenizer


def send_request():
logging.info("Establish the gRPC connection with the model server.")
_PREDICTION_SERVICE_HOST = str(args.external_ip)
_GRPC_PORT = 8500
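    # Raise the gRPC message size limits (the receive default is only 4 MB)
    # so large tensor payloads fit in a single message.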
    options = [
        ("grpc.max_send_message_length", 512 * 1024 * 1024),
        ("grpc.max_receive_message_length", 512 * 1024 * 1024),
    ]
    channel = grpc.insecure_channel(
        f"{_PREDICTION_SERVICE_HOST}:{_GRPC_PORT}", options=options
    )
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    _MAX_INPUT_SIZE = 64
    _BERT_BASE_UNCASED = "bert-base-uncased"
    prompt = [
        "The capital of France is [MASK].",
        "Hello my name [MASK] John, how can I [MASK] you?",
    ]
    # You can also embed the tokenization in the TF2 model; see
    # https://github.com/google/jax/tree/main/jax/experimental/jax2tf#incomplete-tensorflow-data-type-coverage
    tokenizer = AutoTokenizer.from_pretrained(
        _BERT_BASE_UNCASED, model_max_length=_MAX_INPUT_SIZE
    )
logging.info("Tokenize the input sentences.")
    inputs = tokenizer(
        prompt, return_tensors="tf", padding="max_length", truncation=True
    )
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "bert"
    request.model_spec.signature_name = (
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    )
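    # Pack each tokenizer output (input_ids, token_type_ids, attention_mask)
    # into the request as a TensorProto.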
    for key, val in inputs.items():
        request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
    logging.info("Sending the request to the model server.")
    res = stub.Predict(request)
    logging.info("Predict completed.")
    outputs = {
        name: tf.io.parse_tensor(serialized.SerializeToString(), serialized.dtype)
        for name, serialized in res.outputs.items()
    }
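    # Argmax over the vocabulary axis yields the predicted token id at every
    # position, including the [MASK] slots.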
    out_argmaxes = tf.math.argmax(
        outputs["logits"],
        axis=-1,
        output_type=tf.dtypes.int32,
    )
    # Undo padding and print the result.
    length = tf.math.reduce_sum(inputs["attention_mask"], axis=1).numpy()
    for index in range(len(prompt)):
        result = tokenizer.decode(out_argmaxes[index][: length[index]])
        logging.info(f'For input "{prompt[index]}", the result is "{result}".')
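

# `send_request` reads `args.external_ip` as a module-level global, which this
# excerpt never defines. A minimal entry-point sketch, assuming the file parses
# a single --external_ip flag (the real file's flag names may differ):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Send a BERT masked-LM request to a TF Serving gRPC endpoint."
    )
    parser.add_argument(
        "--external_ip",
        required=True,
        help="External IP of the TF Serving model server.",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    send_request()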