cvm-attestation/AttestationClient.py (262 lines of code) (raw):

# AttestationClient.py
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import time
from enum import Enum
from base64 import urlsafe_b64decode
from src.OsInfo import OsInfo
from src.Isolation import IsolationType, Isolation, TdxEvidence, SnpEvidence
from src.Logger import Logger
from src.ReportParser import ReportParser
from src.ImdsClient import ImdsClient
from src.AttestationProvider import MAAProvider, ITAProvider
from AttestationTypes import TpmInfo
from src.measurements import get_measurements
from src.Encoder import Encoder
from tpm_wrapper import TssWrapper
from requests.exceptions import RequestException
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from snp import AttestationReport


# The version number of the attestation protocol between the client and the service.
PROTOCOL_VERSION = "2.0"


class GuestAttestationParameters:
    """
    Guest evidence that AttestationClient.attest_guest serializes and sends to the
    attestation provider.

    Attributes
    ----------
    os_info : OsInfo
        Operating system information (type, distro, major/minor version, build).
    tcg_logs
        TCG measurement logs as returned by get_measurements(); base64-encoded on
        serialization.
    tpm_info : TpmInfo
        TPM evidence; serialized through its get_values() method.
    isolation : Isolation
        Hardware isolation evidence; serialized through its get_values() method.
    """

    def __init__(self, os_info=None, tcg_logs=None, tpm_info=None, isolation=None):
        self.os_info = os_info
        self.tcg_logs = tcg_logs
        self.tpm_info = tpm_info
        self.isolation = isolation

    def toJson(self):
        """
        Serialize the parameters into a JSON document.

        Text fields are base64-encoded with Encoder. All attributes default to None,
        so every one of them must be set before calling this method.

        Returns
        -------
        str
            The JSON document.
        """
        return json.dumps({
            'AttestationProtocolVersion': PROTOCOL_VERSION,
            'OSType': Encoder.base64_encode_string(str(self.os_info.type)),
            'OSDistro': Encoder.base64_encode_string(self.os_info.distro_name),
            'OSVersionMajor': str(self.os_info.major_version),
            'OSVersionMinor': str(self.os_info.minor_version),
            'OSBuild': Encoder.base64_encode_string(self.os_info.build),
            'TcgLogs': Encoder.base64_encode(self.tcg_logs),
            'ClientPayload': Encoder.base64_encode_string(""),
            'TpmInfo': self.tpm_info.get_values(),
            'IsolationInfo': self.isolation.get_values()
        })


class HardwareEvidence:
    """
    A class to represent hardware evidence.

    Attributes
    ----------
    type : str
        The report type extracted from the HCL report (e.g. 'snp' or 'tdx').
    hardware_report : bytes
        The hardware report (the SNP report, or the TD quote for TDX).
    runtime_data : bytes
        The runtime data.

    Raises
    ------
    TypeError
        If report_type is not a str, or hardware_report / runtime_data are not bytes.
    """

    def __init__(self, report_type: str, hardware_report: bytes, runtime_data: bytes):
        if not isinstance(report_type, str):
            raise TypeError(f"Expected str for report_type, got {type(report_type).__name__}")
        if not isinstance(hardware_report, bytes):
            raise TypeError(f"Expected bytes for hardware_report, got {type(hardware_report).__name__}")
        if not isinstance(runtime_data, bytes):
            raise TypeError(f"Expected bytes for runtime_data, got {type(runtime_data).__name__}")

        self.type = report_type
        self.hardware_report = hardware_report
        self.runtime_data = runtime_data


class Verifier(Enum):
    UNDEFINED = 0  # Undefined type
    MAA = 1        # Microsoft Attestation Service
    ITA = 2        # Intel Trusted Authority


class AttestationClientParameters:
    """
    Configuration for AttestationClient.

    Parameters
    ----------
    endpoint : str
        Attestation provider endpoint.
    verifier : Verifier
        Which attestation provider to use.
    isolation_type : IsolationType
        Expected hardware isolation type of this VM.
    claims : optional
        User claims passed to TssWrapper.get_hcl_report when reading the HCL report.
    api_key : optional
        API key; only forwarded to the ITA provider.

    Raises
    ------
    ValueError
        If isolation_type is not an IsolationType or verifier is not a Verifier.
    """

    def __init__(self, endpoint: str, verifier: Verifier, isolation_type: IsolationType, claims=None, api_key=None):
        # Validate the isolation type
        if not isinstance(isolation_type, IsolationType):
            raise ValueError(f"Unsupported isolation type: {isolation_type}. Supported types: {list(IsolationType)}")

        # Validate the verifier
        if not isinstance(verifier, Verifier):
            raise ValueError(f"Unsupported verifier: {verifier}. Supported verifiers: {list(Verifier)}")

        self.endpoint = endpoint
        self.verifier = verifier
        self.api_key = api_key
        self.isolation_type = isolation_type
        self.user_claims = claims


class UnsupportedReportTypeException(Exception):
    """Raised when the HCL report type does not match the configured isolation type."""
    pass


class AttestationClient():
    """
    Collects hardware/guest evidence from the TPM and IMDS and submits it to the
    configured attestation provider (MAA or ITA).
    """

    def __init__(self, logger: Logger, parameters: AttestationClientParameters):
        self.parameters = parameters
        self.log = logger

        isolation_type = parameters.isolation_type
        endpoint = parameters.endpoint
        if parameters.verifier == Verifier.MAA:
            self.provider = MAAProvider(logger, isolation_type, endpoint)
        elif parameters.verifier == Verifier.ITA:
            self.provider = ITAProvider(logger, isolation_type, endpoint, parameters.api_key)
        else:
            # Verifier.UNDEFINED: there is no provider, so attest_guest/attest_platform
            # cannot be used with this client.
            self.provider = None

    def get_hardware_evidence(self) -> HardwareEvidence:
        """
        Collect the hardware evidence for the configured isolation type.

        Reads the HCL report through the TPM, extracts its report type, hardware
        report and runtime data. For TDX the hardware report is exchanged for a TD
        quote via ImdsClient.get_td_quote.

        Returns
        -------
        HardwareEvidence or None
            The collected evidence, or None if any step failed. The failure is logged
            and not re-raised, so callers must check for None.
        """
        try:
            self.log.info('Collecting hardware evidence...')

            # Extract Hardware Report and Runtime Data
            tss_wrapper = TssWrapper(self.log)
            hcl_report = tss_wrapper.get_hcl_report(self.parameters.user_claims)
            report_type = ReportParser.extract_report_type(hcl_report)
            hw_report = ReportParser.extract_hw_report(hcl_report)
            runtime_data = ReportParser.extract_runtimes_data(hcl_report)

            isolation_type = self.parameters.isolation_type
            if report_type == 'snp' and isolation_type == IsolationType.SEV_SNP:
                # Logs important SNP fields from the hardware report
                self.log_snp_report(hw_report)
            elif report_type == 'tdx' and isolation_type == IsolationType.TDX:
                self.log.info("Fetching td quote...")
                imds_client = ImdsClient(self.log)
                encoded_report = Encoder.base64url_encode(hw_report)
                encoded_hw_evidence = imds_client.get_td_quote(encoded_report)
                hw_report = Encoder.base64url_decode(encoded_hw_evidence)
                self.log.info("Finished fetching td quote")
                self.log.info("Hardware report parsing for TDX not supported yet")
            else:
                raise UnsupportedReportTypeException(f"Unsupported report type: {report_type}")

            return HardwareEvidence(report_type, hw_report, runtime_data)
        except Exception as e:
            # Deliberate best-effort: log and signal failure with None.
            self.log.error(f"Error while reading hardware report. Exception {e}")
            return None

    def attest_guest(self):
        """
        Attest the Hardware and Guest.

        Sends the following evidence to the attestation provider:
        - hardware evidence;
        - TPM evidence (AIK certificate/public key, PCR quote and values, ephemeral key);
        - TCG logs;
        - OS information.

        The provider's response is decrypted by _decrypt_guest_token. A failed
        attempt is retried with exponential backoff, up to max_retries attempts.
        Failed attempts are: a RequestException, no response, or hardware evidence
        that could not be collected.

        Returns
        -------
        bytes or None
            The decrypted token (UTF-8 encoded), or None if every attempt failed or the
            hardware report type is invalid.
        """
        # Attest the guest using exponential backoff
        max_retries = 5
        retries = 0
        backoff_factor = 1

        while retries < max_retries:
            try:
                self.log.info('Attesting Guest Evidence...')
                hardware_evidence = self.get_hardware_evidence()
                if hardware_evidence is None:
                    # get_hardware_evidence() already logged the cause (which may be a
                    # transient IMDS error); count this as a failed attempt instead of
                    # crashing on the attribute access below.
                    retries += 1
                    self._wait_before_retry(
                        retries, max_retries, backoff_factor,
                        "Hardware evidence could not be collected after all retries have been exhausted"
                    )
                    continue

                hw_report = hardware_evidence.hardware_report  # td_quote or snp report
                runtime_data = hardware_evidence.runtime_data
                report_type = hardware_evidence.type

                # get the isolation information for the platform
                imds_client = ImdsClient(self.log)
                if report_type == 'tdx':
                    hw_evidence = TdxEvidence(hw_report, runtime_data)
                elif report_type == 'snp':
                    cert_chain = imds_client.get_vcek_certificate()
                    hw_evidence = SnpEvidence(hw_report, runtime_data, cert_chain)
                else:
                    self.log.error(f"Invalid Hardware Report Type: {report_type}")
                    return None

                # Collect guest attestation parameters
                os_info = OsInfo()
                tss_wrapper = TssWrapper(self.log)
                aik_cert = tss_wrapper.get_aik_cert()
                aik_pub = tss_wrapper.get_aik_pub()
                pcr_quote, sig = tss_wrapper.get_pcr_quote(os_info.pcr_list)
                pcr_values = tss_wrapper.get_pcr_values(os_info.pcr_list)
                ephemeral_key, key_handle, tpm = tss_wrapper.get_ephemeral_key(os_info.pcr_list)
                tpm_info = TpmInfo(aik_cert, aik_pub, pcr_quote, sig, pcr_values, ephemeral_key)
                tcg_logs = get_measurements(os_info.type)
                isolation = Isolation(self.parameters.isolation_type, hw_evidence)
                param = GuestAttestationParameters(os_info, tcg_logs, tpm_info, isolation)

                # Calls attestation provider with the guest evidence
                request = {
                    "AttestationInfo": Encoder.base64url_encode_string(param.toJson())
                }
                encoded_response = self.provider.attest_guest(request)

                if encoded_response:
                    decrypted_data = self._decrypt_guest_token(
                        encoded_response, tss_wrapper, os_info.pcr_list, key_handle, tpm
                    )
                    encoded_token = decrypted_data.decode('utf-8')
                    self.log.info('TOKEN:')
                    self.log.info(encoded_token)
                    self.provider.print_guest_claims(encoded_token)
                    return decrypted_data

                # No token: retry until all retries have been exhausted
                self.log.error("Token was not received from attestation provider")
                retries += 1
                self._wait_before_retry(
                    retries, max_retries, backoff_factor,
                    "Token was not received from attestation provider"
                )
            except RequestException as e:
                self.log.error(f"Request to attest guest failed with an exception: {e}")
                retries += 1
                self._wait_before_retry(
                    retries, max_retries, backoff_factor,
                    f"Request failed after all retries have been exhausted. Error: {e}"
                )

        return None

    def attest_platform(self):
        """
        Attest the Hardware.

        Reads the HCL report and builds the platform evidence:
        - for TDX, a TD quote from IMDS;
        - for SNP, the SNP report plus its VCEK certificate chain.

        The evidence is sent to the attestation provider together with the runtime
        data. A RequestException or a missing token is retried with exponential
        backoff, up to max_retries attempts.

        Returns
        -------
        The token returned by the provider, or None if every attempt failed or the
        report type does not match the configured isolation type.
        """
        # Attest the platform using exponential backoff
        max_retries = 5
        retries = 0
        backoff_factor = 1

        while retries < max_retries:
            try:
                self.log.info('Attesting Platform Evidence...')
                tss_wrapper = TssWrapper(self.log)
                isolation_type = self.parameters.isolation_type

                # Extract Hardware Report and Runtime Data
                hcl_report = tss_wrapper.get_hcl_report(self.parameters.user_claims)
                report_type = ReportParser.extract_report_type(hcl_report)
                runtime_data = ReportParser.extract_runtimes_data(hcl_report)
                hw_report = ReportParser.extract_hw_report(hcl_report)

                # Set request data based on the platform
                encoded_report = Encoder.base64url_encode(hw_report)
                encoded_runtime_data = Encoder.base64url_encode(runtime_data)

                imds_client = ImdsClient(self.log)
                if report_type == 'tdx' and isolation_type == IsolationType.TDX:
                    encoded_hw_evidence = imds_client.get_td_quote(encoded_report)
                elif report_type == 'snp' and isolation_type == IsolationType.SEV_SNP:
                    # Logs important SNP fields from the hardware report
                    self.log_snp_report(hw_report)
                    cert_chain = imds_client.get_vcek_certificate()
                    snp_report = json.dumps({
                        'SnpReport': encoded_report,
                        'VcekCertChain': Encoder.base64url_encode(cert_chain)
                    })
                    encoded_hw_evidence = Encoder.base64url_encode(bytearray(snp_report.encode('utf-8')))
                else:
                    # Sending empty evidence to the provider cannot succeed, and
                    # retrying will not change the report type: stop here.
                    self.log.error(
                        f"Invalid Hardware Report Type: {report_type} (isolation type: {isolation_type})"
                    )
                    return None

                # verify hardware evidence
                encoded_token = self.provider.attest_platform(encoded_hw_evidence, encoded_runtime_data)

                if encoded_token:
                    self.log.info('TOKEN:')
                    self.log.info(encoded_token)
                    self.provider.print_platform_claims(encoded_token)
                    return encoded_token

                # No token: retry until all retries have been exhausted
                self.log.error("Token was not received from attestation provider")
                retries += 1
                self._wait_before_retry(
                    retries, max_retries, backoff_factor,
                    "Token was not received from attestation provider"
                )
            except RequestException as e:
                self.log.error(f"Request to attest platform failed with an exception: {e}")
                retries += 1
                self._wait_before_retry(
                    retries, max_retries, backoff_factor,
                    f"Request failed after all retries have been exhausted. Error: {e}"
                )

        return None

    def _wait_before_retry(self, retries, max_retries, backoff_factor, exhausted_message):
        """
        Sleep before the next attempt, or log exhausted_message if none remain.

        After the n-th failed attempt (retries == n) the delay is
        backoff_factor * 2 ** (n - 1) seconds.
        """
        if retries < max_retries:
            sleep_time = backoff_factor * (2 ** (retries - 1))
            self.log.info(f"Retrying in {sleep_time} seconds...")
            time.sleep(sleep_time)
        else:
            self.log.error(exhausted_message)

    def _decrypt_guest_token(self, encoded_response, tss_wrapper, pcr_list, key_handle, tpm):
        """
        Decrypt the token contained in an attest_guest provider response.

        The response is base64url-encoded JSON. Its 'EncryptedInnerKey' is decrypted
        with the TPM ephemeral key (key_handle/tpm). That inner key then decrypts
        'Jwt' with AES-GCM, using 'EncryptionParams.Iv' as the nonce,
        'AuthenticationData' as the tag and b'Transport Key' as associated data.

        Returns
        -------
        bytes
            The decrypted token.

        Raises
        ------
        KeyError
            If a response field is missing.
        cryptography.exceptions.InvalidTag
            If AES-GCM authentication fails.
        """
        self.log.info('Parsing encoded token...')

        # Append padding so a response sent without base64 padding still decodes.
        response = json.loads(urlsafe_b64decode(encoded_response + '==').decode('utf-8'))

        # NOTE(review): every field goes through json.dumps before decoding, so a
        # string value keeps its surrounding quotes; presumably Encoder.base64decode
        # tolerates that. Confirm before simplifying.
        encrypted_inner_key = Encoder.base64decode(json.dumps(response['EncryptedInnerKey']))
        iv = Encoder.base64decode(json.dumps(response['EncryptionParams']['Iv']))
        auth_data = Encoder.base64decode(json.dumps(response['AuthenticationData']))

        inner_key = tss_wrapper.decrypt_with_ephemeral_key(
            encrypted_inner_key, pcr_list, key_handle, tpm
        )

        encrypted_jwt = Encoder.base64decode(json.dumps(response['Jwt']))

        aesgcm = AESGCM(inner_key)
        self.log.info('Decrypting JWT...')
        # AESGCM.decrypt expects the authentication tag appended to the ciphertext
        # (its last 16 bytes).
        decrypted_data = aesgcm.decrypt(iv, encrypted_jwt + auth_data, b'Transport Key')
        self.log.info("Decrypted JWT Successfully.")
        return decrypted_data

    def log_snp_report(self, hw_report):
        """
        Logs SNP attestation report fields.

        Parameters
        ----------
        hw_report : bytes
            Raw SNP attestation report, parsed with AttestationReport.deserialize.
        """
        report_instance = AttestationReport.deserialize(hw_report)
        self.log.info(f"Attestation report size: {len(hw_report)} bytes")
        self.log.info(f"Report version: {report_instance.version}")
        self.log.info(f"Report guest svn: {report_instance.guest_svn}")
        self.log.info(f"Current TCB version: {self._format_tcb(report_instance.current_tcb)}")
        self.log.info(f"Reported TCB version: {self._format_tcb(report_instance.reported_tcb)}")
        self.log.info(f"Commited TCB version: {self._format_tcb(report_instance.committed_tcb)}")
        self.log.info(f"Launched TCB version: {self._format_tcb(report_instance.launch_tcb)}")

    @staticmethod
    def _format_tcb(tcb):
        """Return tcb.serialize() as upper-case hex with its byte order reversed."""
        return "".join(f"{byte:02X}" for byte in tcb.serialize()[::-1])