in tutorials-and-examples/tpu-examples/single-host-inference/jax/stable-diffusion/export_stable_diffusion_model.py
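
# Imports assumed by this excerpt (the rest of the file is not shown here):
# standard library, JAX and jax2tf, TensorFlow, Hugging Face diffusers, and
# the TensorFlow Serving request/log protos.
import logging
import os

import jax
import tensorflow as tf
from diffusers import FlaxStableDiffusionPipeline
from jax.experimental import jax2tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
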
def export_stable_diffusion():
    # The Stable Diffusion v1-4 weights come from the CompVis model card on
    # Hugging Face: https://huggingface.co/CompVis/stable-diffusion-v1-4.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
    )
    _TOKEN_LEN = pipeline.tokenizer.model_max_length
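    # _TOKEN_LEN is the tokenizer's maximum sequence length (77 for the CLIP
    # tokenizer used by this pipeline); it fixes the width of the prompt_ids
    # input in the serving signature defined further below.
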
logging.info("Use jax2tf to convert the Flax model.")
# Converter the Jax model to TF2 model
def predict_fn(params, prompt_ids):
return pipeline._generate(
prompt_ids=prompt_ids,
params=params,
# Default values:
# `https://github.com/huggingface/diffusers/blob/v0.8.0/src/diffusers/
# `pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L246`.
prng_seed=jax.random.PRNGKey(0),
num_inference_steps=50,
height=512,
width=512,
guidance_scale=7.5,
)
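
    # Flatten the parameter pytree and wrap each leaf in a tf.Variable so the
    # weights are tracked (and serialized) by the SavedModel rather than baked
    # into the converted graph as constants.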
    params_flat, params_tree = jax.tree_util.tree_flatten(params)
    params_vars_flat = tuple(tf.Variable(p) for p in params_flat)
    params_vars = jax.tree_util.tree_unflatten(params_tree, params_vars_flat)
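
    # jax2tf.convert returns a TF-callable version of predict_fn; wrapping it
    # in tf.function with a fixed input_signature (a single tokenized prompt of
    # shape (1, _TOKEN_LEN)) makes it exportable as a SavedModel signature.
    # with_gradient=False because only inference is exported, and native
    # serialization targets TPU.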
    tf_predict = tf.function(
        lambda prompt_ids: jax2tf.convert(
            predict_fn,
            enable_xla=True,
            with_gradient=False,
            native_serialization_platforms=["tpu"],
        )(params_vars, prompt_ids),
        input_signature=[
            tf.TensorSpec(
                shape=(1, _TOKEN_LEN), dtype=tf.int32, name="prompt_ids"
            ),
        ],
        autograph=False,
    )
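
    # Attach the converted function and the flat tuple of tf.Variables to a
    # tf.Module so both are tracked in the object graph when saving.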
    tf_model = tf.Module()
    tf_model.tf_predict = tf_predict
    tf_model._variables = params_vars_flat

    # Save the TF model.
    logging.info("Save the TF model.")
    _CPU_MODEL_PATH = "/tmp/jax/stable_diffusion_cpu/1"
    _TPU_MODEL_PATH = "/tmp/jax/stable_diffusion_tpu/1"
    tf.io.gfile.makedirs(_CPU_MODEL_PATH)
    tf.io.gfile.makedirs(_TPU_MODEL_PATH)
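
    # The SavedModel is written under the CPU path. The "tpu_func" function
    # alias below tags tf_predict so that a later TPU conversion step (not
    # shown in this excerpt) can locate it; _TPU_MODEL_PATH is created above
    # but not written to in this function.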
    tf.saved_model.save(
        obj=tf_model,
        export_dir=_CPU_MODEL_PATH,
        signatures={
            "serving_default": tf_model.tf_predict.get_concrete_function()
        },
        options=tf.saved_model.SaveOptions(
            function_aliases={"tpu_func": tf_model.tf_predict}
        ),
    )
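
    # TensorFlow Serving replays the PredictionLog records stored under
    # assets.extra/tf_serving_warmup_requests when it loads the model, which
    # warms up (compiles) the graph before real traffic arrives.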
    # Save a warmup request.
    prompt = "Labrador in the style of Hokusai"
    prompt_ids = pipeline.prepare_inputs(prompt)
    tf_inputs = dict()
    tf_inputs["prompt_ids"] = tf.constant(prompt_ids, dtype=tf.int32)
    _EXTRA_ASSETS_DIR = "assets.extra"
    _WARMUP_REQ_FILE = "tf_serving_warmup_requests"
    assets_dir = os.path.join(_CPU_MODEL_PATH, _EXTRA_ASSETS_DIR)
    tf.io.gfile.makedirs(assets_dir)
    with tf.io.TFRecordWriter(
        os.path.join(assets_dir, _WARMUP_REQ_FILE)
    ) as writer:
        request = predict_pb2.PredictRequest()
        for key, val in tf_inputs.items():
            request.inputs[key].MergeFrom(tf.make_tensor_proto(val))
        log = prediction_log_pb2.PredictionLog(
            predict_log=prediction_log_pb2.PredictLog(request=request)
        )
        writer.write(log.SerializeToString())
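

# A minimal entry point, assuming the module is run directly; the original
# file may instead invoke the export via an argument parser or absl.app.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    export_stable_diffusion()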