envelope-encryption-sample/python-cli/cli.py (198 lines of code) (raw):
#!/usr/bin/env python3
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from base64 import b64decode, b64encode, urlsafe_b64encode
from typing import Any

from absl import flags, app
from google.cloud import kms
from cryptography.fernet import Fernet
# Command-line flags read by main(). Every flag defaults to None except
# --num_bytes, which defaults to 32.
FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "mode",
    None,
    ["generate", "encrypt", "decrypt"],
    "The operation to perform.",
)
flags.DEFINE_string(
    "wrapped_key_path",
    None,
    "Path to the file storing the wrapped data encryption key.",
)
flags.DEFINE_string(
    "kek_name", None, "The Cloud KMS key name of the key encryption key."
)
flags.DEFINE_string("keyring_name", None, "The Cloud KMS keyring name.")
flags.DEFINE_string("project_id", None, "The GCP Project ID.")
flags.DEFINE_string("location", None, "The GCP location name.")
flags.DEFINE_integer(
    "num_bytes",
    32,
    "The number of bytes that Data Encryption Key should have.",
)
flags.DEFINE_string("input", None, "Path to the input file.")
flags.DEFINE_string("output", None, "Path to the output file.")
def crc32c(data: bytes) -> int:
    """
    Compute the CRC32C checksum of a byte string.

    Args:
        data: the bytes to checksum.

    Returns:
        The CRC32C checksum of ``data`` as an int.
    """
    # crcmod is imported here, at call time, rather than at module level.
    from crcmod.predefined import mkPredefinedCrcFun  # type: ignore

    checksum_of = mkPredefinedCrcFun("crc-32c")
    return checksum_of(data)
def generate_random_bytes(
    project_id: str, location: str, num_bytes: int
) -> dict:
    """
    Generate random bytes with entropy sourced from the given location.

    Args:
        project_id (string): Google Cloud project ID.
        location (string): Cloud KMS location.
        num_bytes (integer): number of bytes of random data.

    Returns:
        dict: ``data`` holds the random bytes from the response and
        ``crc32c`` holds the response's ``data_crc32c`` value, so the
        caller can verify the bytes if it wants to.
    """
    # Create the client.
    client = kms.KeyManagementServiceClient()

    # Build the location name.
    location_name = client.common_location_path(project_id, location)

    # Call the API, asking for the HSM protection level.
    response = client.generate_random_bytes(
        request={
            "location": location_name,
            "length_bytes": num_bytes,
            "protection_level": kms.ProtectionLevel.HSM,
        }
    )

    return {
        "data": response.data,
        "crc32c": response.data_crc32c,
    }
def gcp_encrypt_symmetric(
    project_id: str,
    location: str,
    keyring_name: str,
    kek_name: str,
    plaintext: str,
) -> kms.EncryptResponse:
    """
    Encrypt plaintext using a symmetric key stored in GCP KMS.

    Args:
        project_id (string): Google Cloud project ID.
        location (string): Cloud KMS location.
        keyring_name (string): Name of the Cloud KMS key ring.
        kek_name (string): Name of the key to use.
        plaintext (string): message to encrypt

    Returns:
        EncryptResponse: Response including the ciphertext.

    Raises:
        Exception: If the response reports that the plaintext checksum was
            not verified, or if the ciphertext checksum in the response
            does not match the ciphertext received.
    """
    # Convert the plaintext to bytes.
    plaintext_bytes = plaintext.encode("utf-8")

    # Optional, but recommended: compute plaintext's CRC32C and send it
    # along with the request.
    plaintext_crc32c = crc32c(plaintext_bytes)

    # Create the client.
    client = kms.KeyManagementServiceClient()

    # Build the full resource name of the key.
    key_path = client.crypto_key_path(
        project_id, location, keyring_name, kek_name
    )

    # Call the API.
    encrypt_response = client.encrypt(
        request={
            "name": key_path,
            "plaintext": plaintext_bytes,
            "plaintext_crc32c": plaintext_crc32c,
        }
    )

    # Optional, but recommended:
    # perform integrity verification on encrypt_response.
    # For more details on ensuring E2E in-transit
    # integrity to and from Cloud KMS visit:
    # https://cloud.google.com/kms/docs/data-integrity-guidelines
    if not encrypt_response.verified_plaintext_crc32c:
        raise Exception(
            "The request sent to the server was corrupted in-transit."
        )
    if encrypt_response.ciphertext_crc32c != crc32c(
        encrypt_response.ciphertext
    ):
        raise Exception(
            "The response received from the server was corrupted in-transit."
        )
    # End integrity verification

    return encrypt_response
def gcp_decrypt_symmetric(
    project_id: str,
    location: str,
    keyring_name: str,
    kek_name: str,
    ciphertext: bytes,
) -> kms.DecryptResponse:
    """
    Decrypt ciphertext with a symmetric key held in GCP KMS.

    Args:
        project_id (string): Google Cloud project ID.
        location (string): Cloud KMS location.
        keyring_name (string): Name of the Cloud KMS key ring.
        kek_name (string): Name of the key to use.
        ciphertext (bytes): Encrypted bytes to decrypt.

    Returns:
        DecryptResponse: Response including plaintext.

    Raises:
        Exception: If the plaintext checksum in the response does not
            match the plaintext received.
    """
    kms_client = kms.KeyManagementServiceClient()

    # Full resource name of the key, built from its components.
    key_path = kms_client.crypto_key_path(
        project_id, location, keyring_name, kek_name
    )

    # The ciphertext's CRC32C is sent along with the request.
    request = {
        "name": key_path,
        "ciphertext": ciphertext,
        "ciphertext_crc32c": crc32c(ciphertext),
    }
    response = kms_client.decrypt(request=request)

    # Integrity verification of the response; see
    # https://cloud.google.com/kms/docs/data-integrity-guidelines
    received_crc32c = crc32c(response.plaintext)
    if response.plaintext_crc32c != received_crc32c:
        raise Exception(
            "The response received from the server was corrupted in-transit."
        )

    return response
def local_encrypt_symmetric(
    data_encryption_key: bytes, plaintext: str
) -> bytes:
    """
    Encrypt a text message locally with Fernet.
    Fernet uses AES128-CBC + HMAC-SHA256 behind the scenes.

    Args:
        data_encryption_key (bytes): DEK handed to Fernet as its key.
        plaintext (string): message to encrypt.

    Returns:
        bytes: the value returned by Fernet's encrypt().
    """
    cipher = Fernet(data_encryption_key)
    message = plaintext.encode()
    return cipher.encrypt(message)
def local_decrypt_symmetric(
    data_encryption_key: bytes, ciphertext: bytes
) -> bytes:
    """
    Decrypt ciphertext locally with Fernet.
    Fernet uses AES128-CBC + HMAC-SHA256 behind the scenes.

    Args:
        data_encryption_key (bytes): DEK handed to Fernet as its key.
        ciphertext (bytes): ciphertext to decrypt.

    Returns:
        bytes: the value returned by Fernet's decrypt().
    """
    return Fernet(data_encryption_key).decrypt(ciphertext)
def save_json_to_file(json_data: Any, file_path: str) -> None:
    """
    Serialize a value as JSON and write it to a file.

    Parameters:
        json_data: The JSON-serializable value to save.
        file_path (str): The path to store the JSON document.

    Raises:
        Exception: Any error raised while opening or writing the file is
            printed and then re-raised unchanged.
    """
    try:
        # Pin the encoding so the file does not depend on the locale.
        with open(file_path, "w", encoding="utf-8") as file:
            json.dump(json_data, file, indent=4)
        # Report success only once the file has been closed.
        print(f"JSON data successfully saved to {file_path}")
    except Exception as e:
        print(f"An error occurred while saving JSON data to file: {e}")
        # Bare raise re-raises the active exception as-is.
        raise
def load_json_from_file(file_path: str) -> Any:
    """
    Load a JSON document from a file.

    Parameters:
        file_path (str): The path to the file where the JSON is stored.

    Returns:
        The decoded JSON value: a dict for a JSON object, a str for a
        JSON string, and so on.

    Raises:
        Exception: Any error raised while opening or parsing the file is
            printed and then re-raised unchanged.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            json_data = json.load(file)
        print(f"JSON data successfully loaded from {file_path}")
        return json_data
    except Exception as e:
        print(f"An error occurred while loading JSON data from file: {e}")
        # Bare raise re-raises the active exception as-is.
        raise
def read_text_file(file_path: str) -> str:
    """
    Read a simple text file.

    Parameters:
        file_path (str): The path to the file where the desired text is
            stored.

    Returns:
        str: The text content of the file.
    """
    # The with-block closes the file handle once the content is read.
    with open(file_path, "r") as file:
        return file.read()
def _unwrap_data_encryption_key(
    project_id: str,
    location: str,
    keyring_name: str,
    kek_name: str,
    wrapped_key_path: str,
) -> bytes:
    """Load the wrapped DEK from wrapped_key_path and decrypt it via KMS."""
    wrapped_key = load_json_from_file(wrapped_key_path)
    response = gcp_decrypt_symmetric(
        project_id=project_id,
        location=location,
        keyring_name=keyring_name,
        kek_name=kek_name,
        ciphertext=b64decode(wrapped_key),
    )
    return response.plaintext


def main(argv):
    """
    Run the operation selected by --mode.

    generate: fetch random bytes from Cloud KMS, base64-encode them, wrap
        the result with the KEK and save it to --wrapped_key_path.
    encrypt: unwrap the key, encrypt the content of --input locally and
        save the base64-encoded result to --output.
    decrypt: unwrap the key, decrypt the base64 content of --input locally
        and save the plaintext to --output.

    Any other mode only prints a message.

    Args:
        argv: Positional command-line arguments (unused).
    """
    del argv  # Unused.

    mode = FLAGS.mode
    project_id = FLAGS.project_id
    kek_name = FLAGS.kek_name
    keyring_name = FLAGS.keyring_name
    location = FLAGS.location
    num_bytes = FLAGS.num_bytes
    wrapped_key_path = FLAGS.wrapped_key_path
    # Named *_path so the builtin input() is not shadowed.
    input_path = FLAGS.input
    output_path = FLAGS.output

    if mode == "generate":
        random_bytes_response = generate_random_bytes(
            project_id=project_id, location=location, num_bytes=num_bytes
        )
        # The unwrapped value is later handed to Fernet as its key, and
        # Fernet documents its key as URL-safe base64, so use that
        # alphabet here.
        encoded_dek = urlsafe_b64encode(
            random_bytes_response["data"]
        ).decode("utf-8")
        wrapped_key = gcp_encrypt_symmetric(
            project_id=project_id,
            location=location,
            keyring_name=keyring_name,
            kek_name=kek_name,
            plaintext=encoded_dek,
        )
        save_json_to_file(
            json_data=b64encode(wrapped_key.ciphertext).decode("utf-8"),
            file_path=wrapped_key_path,
        )
    elif mode == "encrypt":
        dek = _unwrap_data_encryption_key(
            project_id, location, keyring_name, kek_name, wrapped_key_path
        )
        content = read_text_file(input_path)
        ciphertext = local_encrypt_symmetric(
            data_encryption_key=dek, plaintext=content
        )
        save_json_to_file(
            json_data=b64encode(ciphertext).decode("utf-8"),
            file_path=output_path,
        )
    elif mode == "decrypt":
        dek = _unwrap_data_encryption_key(
            project_id, location, keyring_name, kek_name, wrapped_key_path
        )
        # NOTE(review): encrypt mode writes its output as a JSON string,
        # yet the input is read back here as raw text. The non-validating
        # b64decode appears to tolerate the surrounding quotes; confirm
        # before tightening either side.
        content = read_text_file(input_path)
        plaintext = local_decrypt_symmetric(
            data_encryption_key=dek, ciphertext=b64decode(content)
        )
        save_json_to_file(
            json_data=plaintext.decode("utf-8"), file_path=output_path
        )
    else:
        print("Unsupported mode. Please choose generate, encrypt, or decrypt")
# Script entry point: hand main() to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)