cvm-attestation/tpm_wrapper.py (226 lines of code) (raw):

# tpm_wrapper.py
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from hashlib import sha512, sha256
import json
from external.TSS_MSR.src.Tpm import *
from AttestationTypes import *
from external.TSS_MSR.src.Crypt import Crypto as crypto
from src.Logger import Logger
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa


# NV index that get_hcl_report() reads the HCL report from.
HCL_REPORT_INDEX = '0x01400001'
# NV index that get_hcl_report() writes the SHA-512 of the user data into.
HCL_USER_DATA_INDEX = '0x1400002'
# NV index that get_aik_cert() reads the AIK certificate from.
AIK_CERT_INDEX = '0x01C101D0'
# Base persistent handle; the AIK is addressed at this value + 3 below.
AIK_PUB_INDEX = '0x81000000'


class TssWrapper:
    """Helper around the TSS.MSR ``Tpm`` client used by the attestation flow.

    Every public method opens its own TPM connection and closes it before
    returning, except get_ephemeral_key(), which hands the open connection
    back to the caller together with the key handle it created.
    """

    def __init__(self, logger: Logger):
        self.log = logger

    @staticmethod
    def sha256_hash_update(data_chunks):
        """Return the hex SHA-256 over the concatenated ``.digest`` of each item.

        Used to build the PCR digest passed to PolicyPCR() from the PcrValue
        objects returned by get_pcr_values().
        """
        sha256_ctx = sha256()
        for pcr in data_chunks:
            sha256_ctx.update(pcr.digest)
        return sha256_ctx.hexdigest()

    def write_to_nv_index(self, index, user_data):
        """(Re)define a 64-byte owner-authorized NV index and write ``user_data`` into it.

        Args:
            index: NV index as a hex string, e.g. ``'0x1400002'``.
            user_data: bytes written at offset 0; they must fit in the 64-byte
                index defined here.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            attributes = TPMA_NV.OWNERWRITE | TPMA_NV.OWNERREAD
            attributes |= TPMA_NV.AUTHWRITE | TPMA_NV.AUTHREAD
            auth = TPM_HANDLE(TPM_RH.OWNER)
            nv_index = TPM_HANDLE(int(index, 16))
            nv_public = TPMS_NV_PUBLIC(nv_index, TPM_ALG_ID.SHA256, attributes, None, 64)

            # Drop any previous definition so the index is recreated with the
            # attributes and size above; failure here just means it did not exist.
            try:
                tpm.NV_UndefineSpace(auth, nv_index)
            except Exception:
                self.log.info('Index is not defined yet')

            tpm.NV_DefineSpace(auth, None, nv_public)
            tpm.NV_Write(auth, nv_index, user_data, 0)

            # Read back as a sanity check that the write landed.
            data = tpm.NV_Read(auth, nv_index, 64, 0)
            if len(data) != 0:
                self.log.info('Wrote data successfully')
        finally:
            tpm.close()

    def read_nv_index(self, index):
        """Read the full contents of an NV index.

        The size is taken from NV_ReadPublic and the data is fetched in
        chunks of at most 1024 bytes.

        Args:
            index: NV index as a hex string.

        Returns:
            The NV contents as bytes (``b''`` for a zero-sized index).
        """
        tpm = Tpm()
        tpm.connect()
        try:
            handle = TPM_HANDLE(int(index, 16))
            auth = TPM_HANDLE(TPM_RH.OWNER)
            total_size = tpm.NV_ReadPublic(handle).nvPublic.dataSize
            chunk_size = 1024
            chunks = []
            offset = 0
            while offset < total_size:
                length = min(chunk_size, total_size - offset)
                chunks.append(tpm.NV_Read(auth, handle, length, offset))
                offset += length
        finally:
            tpm.close()
        return b''.join(chunks)

    def read_public(self, index):
        """Return the serialized public area of the TPM object at ``index``.

        Transient objects and loaded sessions are flushed first.

        Args:
            index: integer TPM handle.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.cleanSlots(tpm, TPM_HT.TRANSIENT)
            self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
            handle = TPM_HANDLE(index)
            out_pub = tpm.allowErrors().ReadPublic(handle)
            if tpm.lastResponseCode == TPM_RC.SUCCESS:
                # hex() already yields the '0x' prefix.
                self.log.info("Persistent key " + hex(handle.handle) + " already exists")
            else:
                self.log.info("Failed to read Public Area")
        finally:
            tpm.close()
        # NOTE(review): when ReadPublic failed, out_pub is presumably not a usable
        # object and this line raises; callers may rely on that - confirm.
        return out_pub.toBytes()

    def get_hcl_report(self, user_data):
        """Read the HCL report from the vTPM.

        When ``user_data`` is truthy, the SHA-512 of its JSON encoding is
        first written to HCL_USER_DATA_INDEX (presumably so the report
        reflects it - confirm against the HCL contract).

        Returns:
            The raw report bytes read from HCL_REPORT_INDEX.
        """
        self.log.info('Getting hcl report from vTPM...')
        if user_data:
            hash_bytes = sha512(json.dumps(user_data).encode('utf-8')).digest()
            self.write_to_nv_index(HCL_USER_DATA_INDEX, hash_bytes)

        hcl_report = self.read_nv_index(HCL_REPORT_INDEX)
        if hcl_report:
            self.log.info('Got HCL Report from vTPM!')
        else:
            self.log.info('Error while getting HCL report')
        return hcl_report

    def get_aik_cert(self):
        """Return the AIK certificate bytes stored at AIK_CERT_INDEX."""
        return self.read_nv_index(AIK_CERT_INDEX)

    def get_aik_pub(self):
        """Return the serialized public area of the AIK (AIK_PUB_INDEX + 3)."""
        return self.read_public(int(AIK_PUB_INDEX, 16) + 3)

    def get_pcr_quote(self, pcr_list):
        """Quote the SHA-256 PCRs in ``pcr_list`` with the AIK.

        Returns:
            (quote_buf, sig_bytes): the serialized ``quoted`` structure from
            the Quote response and the raw signature bytes.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.log.info('Getting PCR Select')
            pcr_select = self.get_pcr_select(pcr_list)
            sign_handle = TPM_HANDLE(int(AIK_PUB_INDEX, 16) + 3)
            self.log.info('Quoting PCR Values')
            pcr_quote = tpm.Quote(sign_handle, None, TPMS_NULL_SIG_SCHEME(), pcr_select)
            quote_buf = pcr_quote.quoted.toBytes()
            sig_bytes = pcr_quote.signature.sig
        finally:
            tpm.close()
        return quote_buf, sig_bytes

    def get_pcr_select(self, pcr_list):
        """Build a single SHA-256 bank PCR selection for the indices in ``pcr_list``.

        The bitmap is 3 bytes wide, so only PCRs 0-23 are representable;
        higher indices are silently dropped.
        """
        pcr_mask = 0
        for pcr in pcr_list:
            pcr_mask |= 1 << pcr
        # Byte 0 covers PCRs 0-7, byte 1 PCRs 8-15, byte 2 PCRs 16-23.
        # Kept as a mutable list: get_pcr_values() clears bits in place.
        select = [(pcr_mask >> shift) & 0xFF for shift in (0, 8, 16)]
        return [TPMS_PCR_SELECTION(TPM_ALG_ID.SHA256, select)]

    def get_pcr_values(self, pcr_list):
        """Read the SHA-256 values of the PCRs in ``pcr_list``.

        PCR_Read may return fewer PCRs than requested, so the read is repeated,
        clearing the bits the TPM reports as read, until nothing is left.

        Returns:
            A list of PcrValue objects in the order the TPM returned them
            (empty when ``pcr_list`` selects nothing).

        Raises:
            RuntimeError: if PCR_Read returns no values or does not reduce the
                outstanding selection.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.log.info('Reading PCR Values from the TPM...')
            pcr_select = self.get_pcr_select(pcr_list)
            selection = pcr_select[0]
            pcr_values = []
            pcr_values_count = 0
            while any(selection.pcrSelect):
                ret = tpm.PCR_Read(pcr_select)
                values = ret.pcrValues
                selection_out = ret.pcrSelectionOut
                if not values or not selection_out:
                    raise RuntimeError('PCR_Read returned no PCR values')

                for value in values:
                    pcr_values.append(PcrValue(pcr_values_count, value.buffer))
                    pcr_values_count += 1
                # NOTE(review): skipping 3 indices per batch means PcrValue indices
                # are not the real PCR numbers once more than one batch is read
                # (or pcr_list is non-contiguous). Kept for compatibility - confirm.
                pcr_values_count += 3
                self.log.info(selection.pcrSelect)

                # Remove the PCRs the TPM just returned from the outstanding request.
                before = list(selection.pcrSelect)
                for i, read_bits in enumerate(selection_out[0].pcrSelect):
                    selection.pcrSelect[i] &= ~read_bits
                if list(selection.pcrSelect) == before:
                    raise RuntimeError('PCR_Read made no progress on the PCR selection')
            self.log.info('Done reading PCR values')
        finally:
            tpm.close()
        return pcr_values

    def get_ephemeral_key(self, pcr_list):
        """Create a PCR-policy-bound RSA-2048 decryption key and certify it with the AIK.

        The key is created with CreatePrimary under the owner hierarchy. Its
        authPolicy is the PolicyPCR digest of the current values of the PCRs in
        ``pcr_list``, computed with a trial session.

        Returns:
            (EphemeralKey, key_handle, tpm). The TPM connection is deliberately
            left open because the caller needs the key handle; pass both to
            decrypt_with_ephemeral_key(), which closes the connection.

        Raises:
            Exception: if CreatePrimary yields no handle. On any error the TPM
                connection is closed before the exception propagates.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            pcr_select = self.get_pcr_select(pcr_list)
            pcrs = self.get_pcr_values(pcr_list)

            attributes = (
                TPMA_OBJECT.decrypt |
                TPMA_OBJECT.fixedTPM |
                TPMA_OBJECT.fixedParent |
                TPMA_OBJECT.sensitiveDataOrigin |
                TPMA_OBJECT.noDA
            )
            parameters = TPMS_RSA_PARMS(
                TPMT_SYM_DEF_OBJECT(), TPMS_NULL_ASYM_SCHEME(), 2048, 0
            )
            in_public = TPMT_PUBLIC(
                TPM_ALG_ID.SHA256, attributes, None, parameters, TPM2B_PUBLIC_KEY_RSA()
            )
            sign = TPM_HANDLE(int(AIK_PUB_INDEX, 16) + 3)

            # Trial session: used only to have the TPM compute the policy digest.
            nonceCaller = crypto.randomBytes(20)
            respSas = tpm.StartAuthSession(None, None, nonceCaller, None, TPM_SE.TRIAL,
                                           NullSymDef, TPM_ALG_ID.SHA256)
            hSess = respSas.handle
            self.log.info('DRS >> StartAuthSession(POLICY_SESS) returned '
                          + str(tpm.lastResponseCode) + '; sess handle: ' + str(hSess.handle))

            # Retrieve the policy digest computed by the TPM
            pcr_digest = self.sha256_hash_update(pcrs)
            tpm.PolicyPCR(hSess, bytes.fromhex(pcr_digest), pcr_select)
            in_public.authPolicy = tpm.PolicyGetDigest(hSess)
            self.log.info('DRS >> PolicyGetDigest() returned ' + str(tpm.lastResponseCode))

            # Create RSA Key
            idKey = tpm.withSession(NullPwSession) \
                .CreatePrimary(Owner, TPMS_SENSITIVE_CREATE(), in_public, None, pcr_select)
            self.log.info('DRS >> CreatePrimary(idKey) returned ' + str(tpm.lastResponseCode))
            encryption_key = idKey.outPublic.asTpm2B()
            self.log.info('CreatePrimary returned ' + str(tpm.lastResponseCode))
            if not idKey.getHandle():
                # str() is required: in_public is a TPMT_PUBLIC, not a string.
                raise Exception("CreatePrimary failed for " + str(in_public))

            response = tpm.Certify(idKey.getHandle(), sign, 0, TPMS_NULL_ASYM_SCHEME())
            buf = TpmBuffer(response.certifyInfo.asTpm2B()).createObj(TPM2B_ATTEST)
            self.log.info(buf.attestationData.attested)
            certify_info = response.certifyInfo.toBytes()
            signature = response.signature.sig
            ephemeral_key = EphemeralKey(encryption_key, certify_info, signature)

            # Flush loaded sessions (the trial session); the key object stays loaded.
            self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
        except Exception:
            # Nobody receives the connection on failure, so release it here.
            tpm.close()
            raise
        # not closing TPM connection since we need the key handle
        return ephemeral_key, idKey.getHandle(), tpm

    def decrypt_with_ephemeral_key(self, encrypted_data, pcr_list, handle, tpm):
        """Decrypt ``encrypted_data`` with the key created by get_ephemeral_key().

        Satisfies the key's PCR policy with a policy session over the current
        values of the PCRs in ``pcr_list``, then calls RSA_Decrypt.

        Args:
            encrypted_data: ciphertext for the ephemeral RSA key.
            pcr_list: same PCR indices used to create the key.
            handle: key handle returned by get_ephemeral_key().
            tpm: the open connection returned by get_ephemeral_key(); it is
                always closed before this method returns or raises.

        Returns:
            The decrypted data, or ``""`` if RSA_Decrypt fails (transient
            objects and loaded sessions are flushed in that case).
        """
        try:
            pcr_select = self.get_pcr_select(pcr_list)
            pcrs = self.get_pcr_values(pcr_list)

            nonceCaller = crypto.randomBytes(20)
            respSas = tpm.StartAuthSession(None, None, nonceCaller, None, TPM_SE.POLICY,
                                           NullSymDef, TPM_ALG_ID.SHA256)
            hSess = respSas.handle
            self.log.info('DRS >> StartAuthSession(POLICY_SESS) returned '
                          + str(tpm.lastResponseCode) + '; sess handle: ' + str(hSess.handle))
            sess = Session(hSess, respSas.nonceTPM)

            # Satisfy the key's PolicyPCR with the current PCR values.
            pcr_digest = self.sha256_hash_update(pcrs)
            tpm.PolicyPCR(hSess, bytes.fromhex(pcr_digest), pcr_select)
            self.log.info('DRS >> PolicyGetDigest() returned ' + str(tpm.lastResponseCode))

            try:
                decrypted_data = tpm.withSession(sess).RSA_Decrypt(
                    handle, encrypted_data, TPMS_SCHEME_RSAES(), None)
                self.log.info('Decrypted Inner Decryption Key...')
                return decrypted_data
            except Exception as e:
                # Best effort: report the failure, clear the tpm slots, return "".
                self.log.info(f"Exception: {e}")
                self.cleanSlots(tpm, TPM_HT.TRANSIENT)
                self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
                return ""
        finally:
            tpm.close()

    def cleanSlots(self, tpm, slotType):
        """Flush, or for persistent handles evict, the handles of ``slotType``.

        Only the first batch returned by GetCapability (at most 8 handles) is
        processed.

        Args:
            tpm: a connected Tpm instance.
            slotType: a TPM_HT handle type, e.g. TPM_HT.TRANSIENT.
        """
        caps = tpm.GetCapability(TPM_CAP.HANDLES, slotType << 24, 8)
        handles = caps.capabilityData
        if len(handles.handle) == 0:
            self.log.info(f"No dangling {slotType} handles")
            return
        for h in handles.handle:
            self.log.info(f"Dangling {slotType} handle {hex(h.handle)}")
            if slotType == TPM_HT.PERSISTENT:
                tpm.allowErrors().EvictControl(TPM_HANDLE(TPM_RH.OWNER), h, h)
                if tpm.lastResponseCode not in [TPM_RC.SUCCESS, TPM_RC.HIERARCHY]:
                    raise tpm.lastError
            else:
                tpm.FlushContext(h)