cvm-attestation/tpm_wrapper.py (226 lines of code) (raw):
# tpm_wrapper.py
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from hashlib import sha512, sha256
import json
from external.TSS_MSR.src.Tpm import *
from AttestationTypes import *
from external.TSS_MSR.src.Crypt import Crypto as crypto
from src.Logger import Logger
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
# TPM NV indices / handles, stored as hex strings and parsed with
# int(..., 16) by TssWrapper. All are written zero-padded to 8 hex digits.
HCL_REPORT_INDEX = '0x01400001'     # HCL report is read from this NV index
HCL_USER_DATA_INDEX = '0x01400002'  # SHA-512 of caller user data is written here
AIK_CERT_INDEX = '0x01C101D0'       # AIK certificate is read from this NV index
# Base persistent handle; TssWrapper adds 3 to it (0x81000003) to address the AIK.
AIK_PUB_INDEX = '0x81000000'
class TssWrapper:
    """Helper around the TSS.MSR ``Tpm`` client used for CVM attestation.

    Every public method opens its own TPM connection and closes it before
    returning, with one exception: ``get_ephemeral_key`` returns its open
    connection to the caller so the created key handle stays usable, and
    ``decrypt_with_ephemeral_key`` closes that connection when done.
    """

    # PCR selection bitmaps are built with this many bytes (PCRs 0-23).
    _PCR_SELECT_BYTES = 3

    def __init__(self, logger: Logger):
        """Store the logger used for progress and diagnostic messages."""
        self.log = logger

    @staticmethod
    def sha256_hash_update(data_chunks):
        """Return the hex SHA-256 of the concatenated ``.digest`` of each item.

        Used to compute the expected PCR composite digest handed to
        ``PolicyPCR``.

        Args:
            data_chunks: iterable of objects exposing a bytes-like ``digest``
                attribute (the ``PcrValue`` list from ``get_pcr_values``).

        Returns:
            The lowercase hexadecimal digest string.
        """
        sha256_ctx = sha256()
        for pcr in data_chunks:
            sha256_ctx.update(pcr.digest)
        return sha256_ctx.hexdigest()

    @staticmethod
    def _selected_pcr_indices(pcr_select_bytes):
        """Return the PCR indices whose bits are set in a selection bitmap.

        Bit ``b`` of byte ``i`` selects PCR ``8 * i + b``. Indices are returned
        in ascending order, which is the order in which TPM2_PCR_Read returns
        the corresponding digests.
        """
        return [
            byte_index * 8 + bit
            for byte_index, bits in enumerate(pcr_select_bytes)
            for bit in range(8)
            if bits & (1 << bit)
        ]

    def write_to_nv_index(self, index, user_data):
        """(Re)define a 64-byte owner/auth read-write NV index and write data.

        Any existing definition of the index is removed first, so previous
        contents are discarded.

        Args:
            index: NV index as a hex string (e.g. ``'0x01400002'``).
            user_data: bytes to write at offset 0 (at most 64 bytes, the size
                the index is defined with).
        """
        tpm = Tpm()
        tpm.connect()
        try:
            attributes = TPMA_NV.OWNERWRITE | TPMA_NV.OWNERREAD
            attributes |= TPMA_NV.AUTHWRITE | TPMA_NV.AUTHREAD
            auth = TPM_HANDLE(TPM_RH.OWNER)
            nv_index = TPM_HANDLE(int(index, 16))
            nv_public = TPMS_NV_PUBLIC(nv_index, TPM_ALG_ID.SHA256, attributes, None, 64)
            # Undefine the space if it already exists; failure here just means
            # there was nothing to remove.
            try:
                tpm.NV_UndefineSpace(auth, nv_index)
            except Exception:
                self.log.info('Index is not defined yet')
            tpm.NV_DefineSpace(auth, None, nv_public)
            tpm.NV_Write(auth, nv_index, user_data, 0)
            # Read back to confirm the write landed.
            data = tpm.NV_Read(auth, nv_index, 64, 0)
            if len(data) != 0:
                self.log.info('Wrote data successfully')
        finally:
            tpm.close()

    def read_nv_index(self, index):
        """Read the full contents of an NV index.

        The index size is taken from ``NV_ReadPublic`` and the data is read in
        chunks of at most 1024 bytes.

        Args:
            index: NV index as a hex string.

        Returns:
            The NV contents as ``bytes``.
        """
        buffer_size = 1024
        tpm = Tpm()
        tpm.connect()
        try:
            handle = TPM_HANDLE(int(index, 16))
            auth = TPM_HANDLE(TPM_RH.OWNER)
            total_bytes_to_read = tpm.NV_ReadPublic(handle).nvPublic.dataSize
            chunks = []
            bytes_read = 0
            while bytes_read < total_bytes_to_read:
                bytes_to_read = min(buffer_size, total_bytes_to_read - bytes_read)
                chunks.append(tpm.NV_Read(auth, handle, bytes_to_read, bytes_read))
                bytes_read += bytes_to_read
        finally:
            tpm.close()
        # Join once instead of repeated concatenation (which is quadratic).
        return b''.join(chunks)

    def read_public(self, index):
        """Return the serialized public area of the object at ``index``.

        Dangling transient objects and loaded sessions are flushed first.

        Args:
            index: integer TPM handle (e.g. ``0x81000003``).

        Returns:
            ``bytes`` from ``toBytes()`` on the ``ReadPublic`` response.
            NOTE(review): when ``ReadPublic`` fails (errors are allowed here),
            the response is presumably ``None`` and ``toBytes()`` raises
            ``AttributeError`` — confirm against TSS.MSR.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.cleanSlots(tpm, TPM_HT.TRANSIENT)
            self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
            handle = TPM_HANDLE(index)
            out_pub = tpm.allowErrors().ReadPublic(handle)
            if tpm.lastResponseCode == TPM_RC.SUCCESS:
                # hex() already includes the "0x" prefix.
                self.log.info("Persistent key " + hex(handle.handle) + " already exists")
            else:
                self.log.info("Failed to read Public Area")
        finally:
            tpm.close()
        return out_pub.toBytes()

    def get_hcl_report(self, user_data):
        """Return the HCL report stored in the vTPM.

        If ``user_data`` is truthy, the SHA-512 of its JSON serialization
        (64 bytes, matching the NV size used by ``write_to_nv_index``) is
        written to ``HCL_USER_DATA_INDEX`` before the report is read.

        Returns:
            The report ``bytes`` read from ``HCL_REPORT_INDEX`` (may be empty).
        """
        self.log.info('Getting hcl report from vTPM...')
        if user_data:
            hash_bytes = sha512(json.dumps(user_data).encode('utf-8')).digest()
            self.write_to_nv_index(HCL_USER_DATA_INDEX, hash_bytes)
        hcl_report = self.read_nv_index(HCL_REPORT_INDEX)
        if hcl_report:
            self.log.info('Got HCL Report from vTPM!')
        else:
            self.log.info('Error while getting HCL report')
        return hcl_report

    def get_aik_cert(self):
        """Return the AIK certificate bytes stored at ``AIK_CERT_INDEX``."""
        return self.read_nv_index(AIK_CERT_INDEX)

    def get_aik_pub(self):
        """Return the serialized public area of the AIK (``AIK_PUB_INDEX + 3``)."""
        return self.read_public(int(AIK_PUB_INDEX, 16) + 3)

    def get_pcr_quote(self, pcr_list):
        """Quote the given SHA-256 PCRs with the AIK.

        Args:
            pcr_list: iterable of PCR indices (0-23).

        Returns:
            ``(quote_buf, sig_bytes)``: the serialized quoted attestation
            structure and the signature's ``sig`` field.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.log.info('Getting PCR Select')
            pcr_select = self.get_pcr_select(pcr_list)
            sign_handle = TPM_HANDLE(int(AIK_PUB_INDEX, 16) + 3)
            self.log.info('Quoting PCR Values')
            pcr_quote = tpm.Quote(sign_handle, None, TPMS_NULL_SIG_SCHEME(), pcr_select)
            quote_buf = pcr_quote.quoted.toBytes()
            sig_bytes = pcr_quote.signature.sig
        finally:
            tpm.close()
        return quote_buf, sig_bytes

    def get_pcr_select(self, pcr_list):
        """Build a single-bank SHA-256 PCR selection for ``pcr_list``.

        Args:
            pcr_list: iterable of PCR indices.

        Returns:
            A one-element list holding a ``TPMS_PCR_SELECTION`` whose
            3-byte bitmap has bit ``i`` set for each requested PCR ``i``.

        Raises:
            ValueError: if an index cannot be represented in the 3-byte bitmap
                (previously such PCRs were silently dropped).
        """
        max_pcr = self._PCR_SELECT_BYTES * 8
        pcr_mask = 0
        for pcr in pcr_list:
            if not 0 <= pcr < max_pcr:
                raise ValueError(
                    f"PCR index {pcr} is outside the supported range 0-{max_pcr - 1}")
            pcr_mask |= 1 << pcr
        # Little-endian byte split: byte 0 covers PCRs 0-7, byte 1 8-15, ...
        select = [(pcr_mask >> (8 * i)) & 0xFF for i in range(self._PCR_SELECT_BYTES)]
        return [TPMS_PCR_SELECTION(TPM_ALG_ID.SHA256, select)]

    def get_pcr_values(self, pcr_list):
        """Read the SHA-256 values of the requested PCRs.

        TPM2_PCR_Read may serve only part of the selection per call, so this
        keeps reading and clearing the served bits until nothing is pending.

        Args:
            pcr_list: iterable of PCR indices (0-23).

        Returns:
            A list of ``PcrValue(pcr_index, digest)`` in ascending PCR order.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            self.log.info('Reading PCR Values from the TPM...')
            pcr_select = self.get_pcr_select(pcr_list)
            pending = pcr_select[0].pcrSelect
            pcr_values = []
            while any(pending):
                ret = tpm.PCR_Read(pcr_select)
                pcr_vals = ret.pcrValues
                pcr_sel = ret.pcrSelectionOut
                if not pcr_vals or not pcr_sel:
                    # Nothing served: stop instead of spinning forever.
                    self.log.info('PCR_Read returned no values; stopping')
                    break
                served = pcr_sel[0].pcrSelect
                # Map each returned digest to the PCR index it belongs to,
                # derived from the TPM's own pcrSelectionOut.
                for pcr_index, value in zip(self._selected_pcr_indices(served), pcr_vals):
                    pcr_values.append(PcrValue(pcr_index, value.buffer))
                self.log.info(pending)
                before = list(pending)
                for i in range(min(len(served), len(pending))):
                    pending[i] &= ~served[i]
                if list(pending) == before:
                    # No progress on the pending mask: avoid an infinite loop.
                    self.log.info('PCR_Read made no progress; stopping')
                    break
            self.log.info('Done reading PCR values')
        finally:
            tpm.close()
        return pcr_values

    def get_ephemeral_key(self, pcr_list):
        """Create a PCR-bound RSA-2048 decryption key and certify it with the AIK.

        A trial policy session computes the ``PolicyPCR`` digest over the
        current PCR values. That digest becomes the key's ``authPolicy``, so
        the key is usable only while those PCRs are unchanged.

        Args:
            pcr_list: iterable of PCR indices (0-23) to bind the key to.

        Returns:
            ``(ephemeral_key, key_handle, tpm)``. The TPM connection is left
            open so ``key_handle`` remains valid; pass ``tpm`` to
            ``decrypt_with_ephemeral_key``, which closes it. On failure the
            connection is closed before the exception propagates.
        """
        tpm = Tpm()
        tpm.connect()
        try:
            pcr_select = self.get_pcr_select(pcr_list)
            pcrs = self.get_pcr_values(pcr_list)
            attributes = (
                TPMA_OBJECT.decrypt |
                TPMA_OBJECT.fixedTPM |
                TPMA_OBJECT.fixedParent |
                TPMA_OBJECT.sensitiveDataOrigin |
                TPMA_OBJECT.noDA
            )
            parameters = TPMS_RSA_PARMS(
                TPMT_SYM_DEF_OBJECT(),
                TPMS_NULL_ASYM_SCHEME(),
                2048,
                0
            )
            in_public = TPMT_PUBLIC(
                TPM_ALG_ID.SHA256, attributes,
                None,
                parameters,
                TPM2B_PUBLIC_KEY_RSA()
            )
            sign = TPM_HANDLE(int(AIK_PUB_INDEX, 16) + 3)
            # Trial session: only used to compute the PolicyPCR digest.
            nonce_caller = crypto.randomBytes(20)
            resp_sas = tpm.StartAuthSession(None, None, nonce_caller, None, TPM_SE.TRIAL, NullSymDef, TPM_ALG_ID.SHA256)
            h_sess = resp_sas.handle
            self.log.info('DRS >> StartAuthSession(POLICY_SESS) returned ' + str(tpm.lastResponseCode) + '; sess handle: ' + str(h_sess.handle))
            pcr_digest = self.sha256_hash_update(pcrs)
            tpm.PolicyPCR(h_sess, bytes.fromhex(pcr_digest), pcr_select)
            # Retrieve the policy digest computed by the TPM and bind the key to it.
            in_public.authPolicy = tpm.PolicyGetDigest(h_sess)
            self.log.info('DRS >> PolicyGetDigest() returned ' + str(tpm.lastResponseCode))
            id_key = tpm.withSession(NullPwSession) \
                .CreatePrimary(Owner, TPMS_SENSITIVE_CREATE(), in_public, None, pcr_select)
            self.log.info('DRS >> CreatePrimary(idKey) returned ' + str(tpm.lastResponseCode))
            encryption_key = id_key.outPublic.asTpm2B()
            self.log.info('CreatePrimary returned ' + str(tpm.lastResponseCode))
            key_handle = id_key.getHandle()
            if not key_handle:
                # str() is required: concatenating the TPMT_PUBLIC object itself
                # would raise TypeError and hide this error.
                raise Exception("CreatePrimary failed for " + str(in_public))
            response = tpm.Certify(key_handle, sign, 0, TPMS_NULL_ASYM_SCHEME())
            attest = TpmBuffer(response.certifyInfo.asTpm2B()).createObj(TPM2B_ATTEST)
            self.log.info(attest.attestationData.attested)
            certify_info = response.certifyInfo.toBytes()
            signature = response.signature.sig
            ephemeral_key = EphemeralKey(encryption_key, certify_info, signature)
            self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
        except BaseException:
            # The connection is handed to the caller only on success.
            tpm.close()
            raise
        return ephemeral_key, key_handle, tpm

    def decrypt_with_ephemeral_key(self, encrypted_data, pcr_list, handle, tpm):
        """Decrypt ``encrypted_data`` with a key from ``get_ephemeral_key``.

        Starts a real policy session and satisfies ``PolicyPCR`` with the
        current PCR values, then performs RSAES decryption under that session.
        ``tpm`` is always closed before returning.

        Args:
            encrypted_data: ciphertext bytes.
            pcr_list: the same PCR indices the key was bound to.
            handle: key handle returned by ``get_ephemeral_key``.
            tpm: open connection returned by ``get_ephemeral_key``.

        Returns:
            The decrypted data, or ``""`` if ``RSA_Decrypt`` raised (in which
            case transient objects and loaded sessions are flushed).
        """
        try:
            pcr_select = self.get_pcr_select(pcr_list)
            pcrs = self.get_pcr_values(pcr_list)
            nonce_caller = crypto.randomBytes(20)
            resp_sas = tpm.StartAuthSession(None, None, nonce_caller, None, TPM_SE.POLICY, NullSymDef, TPM_ALG_ID.SHA256)
            h_sess = resp_sas.handle
            self.log.info('DRS >> StartAuthSession(POLICY_SESS) returned ' + str(tpm.lastResponseCode) + '; sess handle: ' + str(h_sess.handle))
            sess = Session(h_sess, resp_sas.nonceTPM)
            pcr_digest = self.sha256_hash_update(pcrs)
            tpm.PolicyPCR(h_sess, bytes.fromhex(pcr_digest), pcr_select)
            self.log.info('DRS >> PolicyGetDigest() returned ' + str(tpm.lastResponseCode))
            try:
                decrypted_data = tpm.withSession(sess).RSA_Decrypt(handle, encrypted_data, TPMS_SCHEME_RSAES(), None)
            except Exception as e:
                self.log.info(f"Exception: {e}")
                self.cleanSlots(tpm, TPM_HT.TRANSIENT)
                self.cleanSlots(tpm, TPM_HT.LOADED_SESSION)
                return ""
            # NOTE(review): on success the transient key and policy session are
            # not explicitly flushed — confirm closing the connection releases them.
            self.log.info('Decrypted Inner Decryption Key...')
            return decrypted_data
        finally:
            tpm.close()

    def cleanSlots(self, tpm, slotType):
        """Flush (or evict, for persistent handles) handles of one type.

        Queries up to 8 handles starting at the first handle of ``slotType``
        (the handle type occupies the top byte, hence ``<< 24``). Persistent
        handles are evicted via ``EvictControl``, tolerating ``TPM_RC.HIERARCHY``.
        All other handles are released with ``FlushContext``.

        Args:
            tpm: an open ``Tpm`` connection.
            slotType: a ``TPM_HT`` value.
        """
        caps = tpm.GetCapability(TPM_CAP.HANDLES, slotType << 24, 8)
        handles = caps.capabilityData
        if len(handles.handle) == 0:
            self.log.info(f"No dangling {slotType} handles")
            return
        for h in handles.handle:
            self.log.info(f"Dangling {slotType} handle {hex(h.handle)}")
            if slotType == TPM_HT.PERSISTENT:
                tpm.allowErrors().EvictControl(TPM_HANDLE.OWNER, h, h)
                if tpm.lastResponseCode not in [TPM_RC.SUCCESS, TPM_RC.HIERARCHY]:
                    raise tpm.lastError
            else:
                tpm.FlushContext(h)