attacks/tf_ace_get_pwned.py (12 lines of code) (raw):
import base64
import json
import h5py
import tensorflow as tf
new_model = tf.keras.models.load_model("tf.h5")
print("Transformers is not vulnerable to this, as it uses h5 directly.")
print("Keras uses a pickled code of the function within the `h5` attrs of the file")
print("Let's show you the marshalled code")
with h5py.File("tf_ace.h5") as f:
data = json.loads(f.attrs["model_config"])
print(base64.b64decode(data["config"]["layers"][-1]["config"]["function"][0]))
pass