in tutorials-and-examples/tpu-examples/single-host-inference/tf/resnet50/export_resnet_model.py [0:0]
def export_resnet(
    *,
    cpu_model_path="/tmp/tf/resnet_cpu/1",
    tpu_model_path="/tmp/tf/resnet_tpu/1",
):
    """Export ResNet-50 as a SavedModel and attach a TF Serving warmup request.

    The model is saved to ``cpu_model_path`` with:

    * a ``serving_default`` signature built from ``resnet.serve``;
    * a ``"tpu_func"`` function alias pointing at ``resnet.tpu_func``.

    One warmup record is then written to
    ``<cpu_model_path>/assets.extra/tf_serving_warmup_requests``. The record
    is a ``PredictionLog`` whose ``PredictRequest`` carries a random
    ``(1, 224, 224, 3)`` float32 tensor under the ``"image"`` key.

    Args:
        cpu_model_path: Versioned SavedModel directory that receives the
            exported model and its warmup asset. Keyword-only; defaults to
            the path that was previously hard-coded.
        tpu_model_path: Directory that is created (but not written to) by
            this function. NOTE(review): presumably the output location of a
            separate TPU conversion step; confirm against the tutorial.
    """
    logging.info("Load TF ResNet-50 model.")
    resnet = ResNetModel()

    # Save the TF model.
    logging.info("Save the TF model as Saved Model format.")
    tf.io.gfile.makedirs(cpu_model_path)
    tf.io.gfile.makedirs(tpu_model_path)
    tf.saved_model.save(
        obj=resnet,
        export_dir=cpu_model_path,
        signatures={"serving_default": resnet.serve.get_concrete_function()},
        # Record resnet.tpu_func under a stable alias so later tooling can
        # find it by name. NOTE(review): assumed consumer is the TPU
        # inference converter; confirm.
        options=tf.saved_model.SaveOptions(
            function_aliases={"tpu_func": resnet.tpu_func}
        ),
    )

    # Save a warmup request. TF Serving looks for warmup records at this
    # fixed assets.extra/tf_serving_warmup_requests location inside the
    # SavedModel version directory.
    inputs = {"image": tf.random.uniform((1, 224, 224, 3), dtype=tf.float32)}
    assets_dir = os.path.join(cpu_model_path, "assets.extra")
    tf.io.gfile.makedirs(assets_dir)
    warmup_path = os.path.join(assets_dir, "tf_serving_warmup_requests")
    with tf.io.TFRecordWriter(warmup_path) as writer:
        request = predict_pb2.PredictRequest()
        for key, val in inputs.items():
            request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
        log = prediction_log_pb2.PredictionLog(
            predict_log=prediction_log_pb2.PredictLog(request=request)
        )
        writer.write(log.SerializeToString())