def export_resnet()

in tutorials-and-examples/tpu-examples/single-host-inference/tf/resnet50/export_resnet_model.py [0:0]

import logging
import os

import tensorflow as tf

# PredictRequest / PredictionLog come from the tensorflow-serving-api package;
# ResNetModel is assumed to be defined or imported elsewhere in this file.
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2

def export_resnet():
  logging.info("Load TF ResNet-50 model.")
  resnet = ResNetModel()
  # Save the TF model.
  logging.info("Save the TF model in the SavedModel format.")
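  # TF Serving expects each model under a numeric version subdirectory,
  # hence the trailing "/1" in both export paths.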
  _CPU_MODEL_PATH = "/tmp/tf/resnet_cpu/1"
  _TPU_MODEL_PATH = "/tmp/tf/resnet_tpu/1"
  tf.io.gfile.makedirs(_CPU_MODEL_PATH)
  tf.io.gfile.makedirs(_TPU_MODEL_PATH)
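  # The SavedModel is exported with a "tpu_func" function alias so that a
  # later TPU conversion step can locate the function to compile for TPU;
  # the TPU directory created above is presumably the destination for that
  # converted model.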
  tf.saved_model.save(
      obj=resnet,
      export_dir=_CPU_MODEL_PATH,
      signatures={"serving_default": resnet.serve.get_concrete_function()},
      options=tf.saved_model.SaveOptions(
          function_aliases={"tpu_func": resnet.tpu_func}
      ),
  )
  # Save a warmup request.
  inputs = {"image": tf.random.uniform((1, 224, 224, 3), dtype=tf.float32)}
  _EXTRA_ASSETS_DIR = "assets.extra"
  _WARMUP_REQ_FILE = "tf_serving_warmup_requests"
  assets_dir = os.path.join(_CPU_MODEL_PATH, _EXTRA_ASSETS_DIR)
  tf.io.gfile.makedirs(assets_dir)
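  # TF Serving replays the records in assets.extra/tf_serving_warmup_requests
  # when it loads the model, warming up the serving path before real traffic.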
  with tf.io.TFRecordWriter(
      os.path.join(assets_dir, _WARMUP_REQ_FILE)
  ) as writer:
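    # Each warmup entry is a serialized PredictionLog wrapping a PredictRequest
    # whose input keys mirror the serving signature.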
    request = predict_pb2.PredictRequest()
    for key, val in inputs.items():
      request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
    log = prediction_log_pb2.PredictionLog(
        predict_log=prediction_log_pb2.PredictLog(request=request)
    )
    writer.write(log.SerializeToString())
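
For reference, a quick local check of the exported artifacts could look like the sketch below. It is not part of the original script: the function name `check_export`, the hard-coded paths, and the assumption that the serving signature's input tensor is named "image" (as the warmup request suggests) are all illustrative.

def check_export():
  # Reload the exported CPU SavedModel and replay the recorded warmup request
  # through its serving signature as a sanity check (illustrative only).
  model = tf.saved_model.load("/tmp/tf/resnet_cpu/1")
  serve_fn = model.signatures["serving_default"]
  record_path = "/tmp/tf/resnet_cpu/1/assets.extra/tf_serving_warmup_requests"
  for record in tf.data.TFRecordDataset([record_path]):
    log = prediction_log_pb2.PredictionLog.FromString(record.numpy())
    image = tf.make_ndarray(log.predict_log.request.inputs["image"])
    outputs = serve_fn(image=tf.constant(image))
    logging.info("Warmup replay produced outputs: %s", list(outputs.keys()))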